In [None]:
# List the NVIDIA GPUs visible to this runtime (confirms a CUDA device is attached).
!nvidia-smi -L

### Configs

In [None]:
# Run metadata: used to name the W&B project/run and the checkpoint files.
PROJECT_NAME = "Fixmatch_Multioutput_Implementation"
ARCHITECTURE_NAME = "tf_efficientnet_lite0"   # timm model name
MODEL_TYPE = "self_supervised"
NAME = "Rifat"  # your name here


## Install dependencies

In [None]:
# Install third-party packages used below (timm backbones/RandAugment, Kaggle CLI, OpenCV, W&B).
# NOTE(review): `imutils` is imported in the next section but not installed here;
# presumably it is preinstalled in the Colab runtime — confirm.
!pip install -q timm 
!pip install -q --upgrade --force-reinstall --no-deps kaggle
# NOTE(review): OpenCV is pinned to an old release; presumably for compatibility
# with the runtime at the time of writing — confirm it is still needed and installable.
!pip install -q opencv-python-headless==4.1.2.30  
!pip install -q  --upgrade wandb 

## **Download the Kaggle Dataset and Create Dataframes**

In [None]:
import os
import pandas as pd
import numpy as np
from imutils import paths
from google.colab import files
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

In [None]:
uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))
  
# Then move kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle/ && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json

**Download and Unzip the Dataset**

In [None]:
!kaggle datasets download -d trolukovich/apparel-images-dataset
!mkdir dataset
!unzip -q apparel-images-dataset.zip -d ./dataset

**Dataframe Shuffle and Split**

In [None]:
# NOTE(review): df_apparel_multilabel.csv is not produced by any cell in this
# notebook; presumably it is uploaded separately — confirm its provenance.
df_apparel = pd.read_csv("/content/df_apparel_multilabel.csv")

# Fixed seed so the unlabeled / labeled / validation split is reproducible
# across Restart & Run All (and comparable between runs).
SPLIT_SEED = 42
df = shuffle(df_apparel, random_state=SPLIT_SEED)

# 80% of rows become unlabeled data; the remaining 20% are split 80/20 into
# labeled training data and validation data. train_test_split shuffles by
# default, so no extra shuffle is needed between the two splits.
ulb_dataframe, lb_dataframe = train_test_split(df, test_size=0.2, random_state=SPLIT_SEED)
lb_dataframe, valid_dataframe = train_test_split(lb_dataframe, test_size=0.2, random_state=SPLIT_SEED)

assert len(ulb_dataframe) + len(lb_dataframe) + len(valid_dataframe) == len(df_apparel)

print(len(ulb_dataframe))
print(len(lb_dataframe))
print(len(valid_dataframe))

9108
1821
456


## **Useful imports**

In [None]:
import pandas as pd
import numpy as np
import cv2

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import transforms

from tqdm.notebook import tqdm

from sklearn.metrics import accuracy_score
from timm.data.auto_augment import rand_augment_transform
import PIL
import matplotlib.pyplot as plt

import timm
import time 
from collections import OrderedDict

import wandb


**Hyperparameters and Global Variables**

In [None]:
# device is set to cuda if cuda is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet channel statistics passed to transforms.Normalize.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

num_epochs = 10
# One model output per label column. Assumes 'ImagePath' is the CSV's only
# non-label column (ImageDataset drops exactly that column to build targets).
total_class = len(df_apparel.columns)-1
threshold = 0.90          # minimum softmax confidence for a FixMatch pseudo-label
learning_rate = 0.001
lb_to_ulb_ratio = 4       # unlabeled batch size = lb_to_ulb_ratio * batch_size_lb
batch_size_lb = 6
# Width of the color head: outputs [0, Color_logits) are color, the rest are cloth.
Color_logits = 6
# Both heads must be non-empty, otherwise one of the per-head losses is meaningless.
assert 0 < Color_logits < total_class, (
    f"Color_logits={Color_logits} must split total_class={total_class} into two non-empty heads")

# Seed Python, NumPy and PyTorch RNGs so augmentation, shuffling and weight
# initialisation are repeatable across Restart & Run All (cuDNN kernels may
# still introduce small nondeterminism).
import random
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

save_path_checkpoints= "/content/model/ckpts"
os.makedirs(save_path_checkpoints, exist_ok=True)

**Dataset Class**

In [None]:
class ImageDataset(Dataset):
    """Image dataset backed by a DataFrame that has an 'ImagePath' column.

    Every column other than 'ImagePath' is treated as part of the target vector.

    * Labeled mode (``is_lb=True``): an item is ``(normalize(image), targets)``,
      where ``targets`` is a float32 numpy array of the label columns.
    * Unlabeled mode: an item is
      ``(normalize(weak_transform(image)), normalize(strong_transform(image)))``,
      i.e. the weak and strong FixMatch views of the same image.

    Images are always decoded as 3-channel RGB so that grayscale, RGBA or
    palette files are compatible with a 3-channel ``Normalize``.
    """

    def __init__(self, dataframe, weak_transform, strong_transform, normalize, is_lb=False):
        self.dataframe = dataframe
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.normalize = normalize
        self.is_lb = is_lb
        self.all_image_names = self.dataframe['ImagePath']
        self.all_image_label = self.dataframe.drop(['ImagePath'], axis=1)

    def __len__(self):
        return len(self.all_image_names)

    def _load_rgb(self, img_path):
        """Open an image file, decode it as RGB and close the file handle."""
        with PIL.Image.open(img_path) as img:
            # convert() decodes the pixels into a new image, so closing the file
            # afterwards is safe; without the `with`, every item leaked a handle.
            return img.convert('RGB')

    def __getitem__(self, index):
        image = self._load_rgb(self.all_image_names.iloc[index])
        if self.is_lb:
            targets = np.array(self.all_image_label.iloc[index], dtype=np.float32)
            return self.normalize(image), targets
        weak_image = self.weak_transform(image)
        strong_image = self.strong_transform(image)
        return self.normalize(weak_image), self.normalize(strong_image)
         

**Dataloaders and Transforms**

In [None]:
def fun_transfrom():
  """Build the unlabeled, labeled and validation DataLoaders.

  * strong view: timm RandAugment ('rand-m9-mstd0.5')
  * weak view:   random horizontal flip
  * every image is then resized to 224x224, converted to a tensor and
    normalized with the global mean/std.

  Labeled and validation datasets are built with ``is_lb=True``; in that mode
  ImageDataset applies only ``normalize`` (no weak/strong augmentation).

  Returns:
    (unlabeled_loader, labeled_loader, validation_loader). The unlabeled
    batch is ``lb_to_ulb_ratio`` times the labeled batch; the validation
    loader yields the whole validation set as one batch.
  """
  strong_transform = rand_augment_transform(config_str='rand-m9-mstd0.5', hparams={})
  weak_transform = transforms.Compose([transforms.RandomHorizontalFlip()])
  normalize = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Normalize(mean=mean, std=std),
  ])

  def _make_dataset(frame, labeled):
    # All three splits share the same transform objects; only the mode differs.
    return ImageDataset(frame, weak_transform, strong_transform, normalize, is_lb=labeled)

  ulb_dataset = _make_dataset(ulb_dataframe, False)
  lb_dataset = _make_dataset(lb_dataframe, True)
  valid_dataset = _make_dataset(valid_dataframe, True)

  ulb_loader = DataLoader(ulb_dataset, batch_size=batch_size_lb * lb_to_ulb_ratio,
                          shuffle=True, num_workers=2)
  lb_loader = DataLoader(lb_dataset, batch_size=batch_size_lb,
                         shuffle=True, num_workers=2)
  valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataframe),
                            shuffle=False, num_workers=2)
  return ulb_loader, lb_loader, valid_loader


dataloader_ulb_dataset, dataloader_lb_dataset, dataloader_valid_dataset = fun_transfrom()

**Wandb**

In [None]:
!wandb login 

In [None]:
class WandbLogger():
    """
    Store W&B run settings on the instance and start the run.

    Every constructor argument is kept as an attribute of the same name and is
    forwarded unchanged, as a keyword argument, to ``wandb.init``.
    """
    def __init__(self, project, entity, name, id, config, resume="allow"):
        self.project, self.entity, self.name = project, entity, name
        self.id, self.config, self.resume = id, config, resume
        init_kwargs = dict(
            project=self.project,
            entity=self.entity,
            name=self.name,
            id=self.id,
            config=self.config,
            resume=self.resume,
        )
        wandb.init(**init_kwargs)

In [None]:
# W&B run identity.
project= PROJECT_NAME
# NOTE(review): hard-coded W&B entity (account/team) — confirm it is the intended owner.
entity="rakib1521"

# The same run name may be reused across runs, but run ids must be unique.
name = f"{PROJECT_NAME}_{ARCHITECTURE_NAME}"
# Stored as `run_id`, not `id`, so the builtin id() is not shadowed for the rest
# of the notebook. The id is deterministic and WandbLogger defaults to
# resume="allow", so re-running presumably continues this same W&B run — confirm
# that is intended.
run_id = f"{PROJECT_NAME}_{ARCHITECTURE_NAME}"

# Hyperparameters recorded with the run.
wandb_config = {"network":ARCHITECTURE_NAME,
                "epoch":num_epochs,
                "batch_size_lb": batch_size_lb,
                "lb_to_ulb_ratio":lb_to_ulb_ratio,
                "learning_rate": learning_rate,
                "probability_threshold": threshold,
                }
wandb_logger = WandbLogger(project, entity, name, run_id, wandb_config)

# Model Definition

In [None]:
# Backbone from timm; the architecture comes from ARCHITECTURE_NAME in the config
# cell so the model, W&B config and checkpoint names cannot drift apart.
# NOTE(review): pretrained weights are not requested here — confirm that training
# from scratch is intended.
model = timm.create_model(ARCHITECTURE_NAME)

# Replace the classification head with one linear layer producing one logit per
# label column: the first Color_logits outputs are the color head and the rest
# the cloth head (see fun_loss_lb / fun_loss_ulb). Kept inside a Sequential
# named 'fc1' so checkpoint keys remain 'classifier.fc1.*'.
classifier = torch.nn.Sequential(OrderedDict([
    ('fc1', torch.nn.Linear(model.classifier.in_features, total_class))
]))

model.classifier = classifier

model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
def fun_loss_lb(logits,true):
    """Supervised loss for the two-headed multi-output classifier.

    The first ``Color_logits`` columns of ``logits`` and ``true`` form the color
    head; the remaining columns form the cloth head. Each head gets a mean
    cross-entropy against its float target columns (presumably one-hot), and
    the two head losses are summed.
    """
    color_cols = slice(None, Color_logits)
    cloth_cols = slice(Color_logits, None)
    color_loss = F.cross_entropy(logits[:, color_cols], true[:, color_cols])
    cloth_loss = F.cross_entropy(logits[:, cloth_cols], true[:, cloth_cols])
    return color_loss + cloth_loss


In [None]:
def pseudo_label_calc(logits, threshold, return_mask=False):
    """Turn per-class scores into hard pseudo-labels.

    Args:
        logits: 2-D tensor, one row per sample. Despite the name, the callers
            in this notebook pass softmax probabilities; ``threshold`` is
            compared with each row's maximum.
        threshold: a row counts as confident when its max value is strictly
            greater than this.
        return_mask: if True, also return the confidence mask (new, optional).

    Returns:
        pseudo_label: 1-D int64 tensor holding the argmax class for confident
            rows and 0 for non-confident rows. That 0 is only a placeholder and
            is indistinguishable from a genuine class-0 label, so a loss should
            exclude those rows via the mask instead of training on them.
        mask (only if ``return_mask``): 1-D float tensor, 1.0 for confident rows
            and 0.0 otherwise.
    """
    max_prob, pseudo_label = torch.max(logits, dim=1)
    mask = (max_prob > threshold).float()
    pseudo_label = pseudo_label.masked_fill(mask == 0, 0)
    if return_mask:
        return pseudo_label, mask
    return pseudo_label

In [None]:
def fun_loss_ulb(logits_weak,logits_strong,threshold):
    """FixMatch unlabeled consistency loss for the color and cloth heads.

    For each head (color = first ``Color_logits`` outputs, cloth = the rest):
    the weak-view softmax gives a hard pseudo-label and its confidence; the
    strong-view logits are trained with cross-entropy toward that pseudo-label,
    but only for samples whose confidence exceeds ``threshold``. Non-confident
    samples contribute zero, and the masked sum is divided by the full
    unlabeled batch size. Otherwise every uncertain sample would be trained
    toward class 0.

    Args:
        logits_weak: logits for the weakly augmented views (no gradient needed).
        logits_strong: logits for the strongly augmented views, same shape.
        threshold: minimum max-softmax probability for a usable pseudo-label.

    Returns:
        Scalar tensor: color-head loss + cloth-head loss.
    """
    def _masked_head_loss(weak_head, strong_head):
        # Pseudo-labels are targets only; never backpropagate through them.
        probs = F.softmax(weak_head.detach(), dim=1)
        max_prob, pseudo_label = torch.max(probs, dim=1)
        mask = (max_prob > threshold).float()
        per_sample = F.cross_entropy(strong_head, pseudo_label, reduction='none')
        return (per_sample * mask).mean()

    loss_color = _masked_head_loss(logits_weak[:, :Color_logits], logits_strong[:, :Color_logits])
    loss_cloth = _masked_head_loss(logits_weak[:, Color_logits:], logits_strong[:, Color_logits:])
    return loss_color + loss_cloth

In [None]:
def train_fixMatch(threshold):
    """Run one FixMatch training epoch over the labeled loader.

    Each labeled batch is paired with the next unlabeled (weak, strong) batch;
    the unlabeled iterator is restarted whenever it is exhausted. Labeled, weak
    and strong images are concatenated into a single forward pass, and the
    optimized loss is ``fun_loss_lb`` on the labeled part plus ``fun_loss_ulb``
    on the unlabeled part.

    Args:
        threshold: confidence threshold forwarded to ``fun_loss_ulb``.

    Returns:
        Mean of the per-batch training losses for the epoch.
    """
    model.train()

    batch_losses = []
    progress = tqdm(dataloader_lb_dataset)
    unlabeled_batches = iter(dataloader_ulb_dataset)

    for x_lb, y in progress:
        model.zero_grad()
        x_lb, y = x_lb.to(device), y.to(device)

        ulb_batch = next(unlabeled_batches, None)
        if ulb_batch is None:
            # Unlabeled loader ran out: start a fresh pass over it.
            unlabeled_batches = iter(dataloader_ulb_dataset)
            ulb_batch = next(unlabeled_batches)
        x_weak, x_strong = (view.to(device) for view in ulb_batch)

        n_lb = x_lb.size(0)
        all_logits = model(torch.cat([x_lb, x_weak, x_strong], dim=0))
        logits_lb = all_logits[:n_lb]
        logits_weak, logits_strong = torch.chunk(all_logits[n_lb:], 2, dim=0)
        # The weak view only supplies pseudo-labels, so stop its gradient.
        logits_weak = logits_weak.detach()

        loss = fun_loss_lb(logits_lb, y) + fun_loss_ulb(logits_weak, logits_strong, threshold)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        progress.set_description(f'train loss = {np.mean(batch_losses): .3f}')

    return np.mean(batch_losses)
@torch.no_grad()
def validate():
    """Evaluate the global ``model`` on ``dataloader_valid_dataset``.

    Predictions and targets are accumulated over every validation batch (not
    just the last one), so the result does not depend on the loader's batch
    size. For each head, the predicted class is the argmax of that head's logit
    slice and the true class is the argmax of the matching label columns.

    Returns:
        (acc_color, acc_cloth): accuracy of the color head (first
        ``Color_logits`` outputs) and of the cloth head (remaining outputs).

    Raises:
        ValueError: if the validation loader yields no samples.
    """
    model.eval()

    color_pred, color_true = [], []
    cloth_pred, cloth_true = [], []

    for x, y in dataloader_valid_dataset:
        logits = model(x.to(device))
        color_pred += logits[:, :Color_logits].argmax(dim=1).tolist()
        cloth_pred += logits[:, Color_logits:].argmax(dim=1).tolist()
        color_true += y[:, :Color_logits].argmax(dim=1).tolist()
        cloth_true += y[:, Color_logits:].argmax(dim=1).tolist()

    if not color_true:
        raise ValueError("validation loader produced no samples")

    acc_color = accuracy_score(color_true, color_pred)
    acc_cloth = accuracy_score(cloth_true, cloth_pred)

    print('acc_color- {}, acc_cloth- {}'.format(acc_color,acc_cloth))

    return acc_color,acc_cloth

In [None]:
# Per-epoch history, used by the plotting cells below.
losses = []
accuracies_color = []
accuracies_cloth = []

for epoch in range(num_epochs):
    train_loss = train_fixMatch(threshold=threshold)

    acc_color, acc_cloth = validate()
    accuracies_color.append(acc_color)
    accuracies_cloth.append(acc_cloth)
    losses.append(train_loss)

    wandb.log({"train_loss": train_loss,
               "valid_acc_color": acc_color,
               "valid_acc_cloth": acc_cloth})

    # One checkpoint per epoch (1-based epoch and the epoch's train loss in the name).
    filepath = (f"{save_path_checkpoints}/{PROJECT_NAME}_{MODEL_TYPE}-"
                f"{ARCHITECTURE_NAME}-{epoch+1}_loss-{train_loss}.pt")
    checkpoint = dict(epoch=epoch + 1,
                      model_weight=model.state_dict(),
                      optimizer_state=optimizer.state_dict())
    torch.save(checkpoint, filepath)
    print(f"{filepath} saved")

In [None]:
# Training loss per epoch. The x-axis is 1-based and sized from the recorded
# history, so the plot still works if training stopped before num_epochs.
plt.plot(np.arange(1, len(losses) + 1), losses)
plt.title('Training loss vs epoch')
plt.xlabel('epoch')
plt.ylabel('training loss')
plt.show()

In [None]:
# Validation accuracy of the color head per epoch (1-based x-axis sized from
# the recorded history, so it also works after an early stop).
plt.plot(np.arange(1, len(accuracies_color) + 1), accuracies_color)
plt.title('Validation accuracy (color) vs epoch')
plt.xlabel('epoch')
plt.ylabel('validation accuracy color')
plt.show()

In [None]:
# Validation accuracy of the cloth head per epoch (1-based x-axis sized from
# the recorded history, so it also works after an early stop).
plt.plot(np.arange(1, len(accuracies_cloth) + 1), accuracies_cloth)
plt.title('Validation accuracy (cloth) vs epoch')
plt.xlabel('epoch')
plt.ylabel('validation accuracy cloth')
plt.show()

In [None]:
@torch.no_grad()
def validate():
    """Evaluate the global ``model`` on ``dataloader_valid_dataset``.

    This cell redefines ``validate``; cells after it use this definition.

    Predictions and targets are accumulated over every validation batch (not
    just the last one), so the result does not depend on the loader's batch
    size. For each head, the predicted class is the argmax of that head's logit
    slice and the true class is the argmax of the matching label columns.

    Returns:
        (acc_color, acc_cloth): accuracy of the color head (first
        ``Color_logits`` outputs) and of the cloth head (remaining outputs).

    Raises:
        ValueError: if the validation loader yields no samples.
    """
    model.eval()

    color_pred, color_y = [], []
    cloth_pred, cloth_y = [], []

    for x, y in dataloader_valid_dataset:
        logits = model(x.to(device))
        color_pred += logits[:, :Color_logits].argmax(dim=1).tolist()
        cloth_pred += logits[:, Color_logits:].argmax(dim=1).tolist()
        color_y += y[:, :Color_logits].argmax(dim=1).tolist()
        cloth_y += y[:, Color_logits:].argmax(dim=1).tolist()

    if not color_y:
        raise ValueError("validation loader produced no samples")

    acc_color = accuracy_score(color_y, color_pred)
    acc_cloth = accuracy_score(cloth_y, cloth_pred)

    print('acc_color- {}, acc_cloth- {}'.format(acc_color,acc_cloth))

    return acc_color,acc_cloth
    

In [None]:
import glob

# Score every saved checkpoint on the validation set (mean of the two head
# accuracies) and reload the best one into `model`.
ckpt_paths = sorted(glob.glob(os.path.join(save_path_checkpoints, "*.pt")))
if not ckpt_paths:
    raise FileNotFoundError(f"no checkpoints found in {save_path_checkpoints}")

acc = []
for path in ckpt_paths:
    # map_location lets GPU-saved checkpoints load on whatever device is active.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_weight'])
    acc_color,acc_cloth  = validate()
    acc.append((acc_color+acc_cloth)/2)

best_path = ckpt_paths[int(np.argmax(acc))]
print(best_path)

print("###")
model.load_state_dict(torch.load(best_path, map_location=device)["model_weight"])
acc_color,acc_cloth  = validate()

In [None]:
# NOTE(review): validate() reads dataloader_valid_dataset, the same split used
# above to pick the best checkpoint, so these "Test_*" values are validation
# accuracies rather than a held-out test estimate.
acc_color, acc_cloth = validate()
wandb.log(dict(Test_acc_color=acc_color, Test_acc_cloth=acc_cloth))

In [None]:
# End the W&B run that WandbLogger started with wandb.init.
wandb.finish()