# Preliminaries

## Install and import libraries 

In [1]:
import torch
import torch.nn.functional as F
import os
import ast
import csv
import pandas as pd

## Set computation engine

In [2]:
# Pick the computation device: the first CUDA GPU when one is visible to
# PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


## Connect to drive

In [3]:
# Mount Google Drive inside the Colab runtime; the data files used by this
# notebook are read from and written to paths under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Load precomputed text features

In [4]:
class TextDataset(torch.utils.data.Dataset):
    """Training set built from pre-computed text feature tensors.

    Three tensors are loaded from ``data_dir``: item-name features,
    item-caption features and labels. All three are indexed as
    ``tensor[:, idx]``, i.e. samples are laid out along dimension 1.

    Args:
        data_dir: Directory containing ``Xtrain_item_name.pt``,
            ``Xtrain_item_caption.pt`` and ``Ytrain_label.pt``. Defaults to
            the Google Drive folder used by this notebook.
    """

    def __init__(self, data_dir='/content/drive/MyDrive/data_rakuten'):
        # Load pre-computed tensors
        self.Xtrain_item_name = torch.load(os.path.join(data_dir, 'Xtrain_item_name.pt'))
        self.Xtrain_item_caption = torch.load(os.path.join(data_dir, 'Xtrain_item_caption.pt'))
        self.Ytrain_label = torch.load(os.path.join(data_dir, 'Ytrain_label.pt'))

    def __len__(self):
        """Return the number of samples (size of dimension 1 of the caption tensor)."""
        return self.Xtrain_item_caption.shape[1]

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample ``idx``.

        ``features`` is the item-name column followed by the item-caption
        column, concatenated along dimension 0; ``label`` is the matching
        column of the label tensor.
        """
        features = torch.cat(
            (self.Xtrain_item_name[:, idx], self.Xtrain_item_caption[:, idx]), 0
        )
        return features, self.Ytrain_label[:, idx]

trainSet = TextDataset()
# Never start more loader workers than the host has CPUs: a fixed 12 workers
# oversubscribes small runtimes. The original value of 12 remains the upper bound.
trainLoader = torch.utils.data.DataLoader(
    trainSet,
    batch_size=10,
    shuffle=True,
    num_workers=min(12, os.cpu_count() or 1),
)

# Our prediction model

## Create the model

In [5]:
class CustomModel(torch.nn.Module):
    """Two-layer classifier head over the text features.

    Architecture: Linear(1536 -> 128) -> ReLU -> Dropout(0.1) -> Linear(128 -> 19).
    The forward pass returns raw logits (no sigmoid/softmax is applied).
    """

    def __init__(self):
        super().__init__()

        # Registration order (dropout, fc1, fc2) is kept as-is so the module
        # repr and the state_dict layout stay the same.
        self.dropout = torch.nn.Dropout(p=0.1)
        self.fc1 = torch.nn.Linear(in_features=1536, out_features=128)
        self.fc2 = torch.nn.Linear(in_features=128, out_features=19)

    def forward(self, text_features):
        """Map a batch of 1536-dimensional feature vectors to 19 logits each."""
        hidden = self.fc1(text_features)
        hidden = F.relu(hidden)
        # Dropout is only active in training mode; it is a no-op after model.eval()
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

model = CustomModel()
# Move the model to the device selected earlier rather than forcing CUDA, so
# the notebook also runs on a CPU-only runtime. `.to()` returns the module, so
# the cell still displays the model summary.
model.to(device)

CustomModel(
  (dropout): Dropout(p=0.1, inplace=False)
  (fc1): Linear(in_features=1536, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=19, bias=True)
)

## Train the model

In [6]:
# Loss: binary cross-entropy computed directly on the raw logits, one
# independent binary decision per output.
criterion = torch.nn.BCEWithLogitsLoss()

# Plain SGD over every model parameter.
optimizer = torch.optim.SGD(params=model.parameters(), lr=5e-4)

# Multiply the learning rate by 0.15 after every 5 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=5,
    gamma=0.15,
)

# Training
def train(epoch):
    """Run one pass over ``trainLoader`` and update ``model`` in place.

    Uses the notebook-level ``model``, ``trainLoader``, ``optimizer``,
    ``criterion`` and ``device``. Every 1000 batches it prints the progress
    through the epoch and the mean loss over the batches seen so far.

    Args:
        epoch: Epoch number; only used for the progress header.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    running_loss = 0
    n_batches = len(trainLoader)

    for step, (features, labels) in enumerate(trainLoader):
        features = features.to(device)
        labels = labels.to(device)

        # Standard optimisation step: reset gradients, forward, backward, update
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Only report on every 1000th batch
        if step % 1000 != 0:
            continue
        progress = 100*step/(n_batches+1)
        mean_loss = running_loss/(step+1)
        print('{:.0f}%|Train Loss: {:.5f} '.format(progress, mean_loss))

In [7]:
# Training loop: one full pass over the training set per epoch, then advance
# the learning-rate scheduler by one step.
NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    train(epoch)
    scheduler.step()


Epoch: 0
0%|Train Loss: 0.70457 
5%|Train Loss: 0.60461 
9%|Train Loss: 0.52776 
14%|Train Loss: 0.47148 
19%|Train Loss: 0.43578 
24%|Train Loss: 0.41197 
28%|Train Loss: 0.39551 
33%|Train Loss: 0.38341 
38%|Train Loss: 0.37446 
42%|Train Loss: 0.36756 
47%|Train Loss: 0.36135 
52%|Train Loss: 0.35643 
57%|Train Loss: 0.35260 
61%|Train Loss: 0.34925 
66%|Train Loss: 0.34630 
71%|Train Loss: 0.34395 
75%|Train Loss: 0.34154 
80%|Train Loss: 0.33929 
85%|Train Loss: 0.33766 
90%|Train Loss: 0.33612 
94%|Train Loss: 0.33482 
99%|Train Loss: 0.33351 

Epoch: 1
0%|Train Loss: 0.25280 
5%|Train Loss: 0.30643 
9%|Train Loss: 0.30905 
14%|Train Loss: 0.30878 
19%|Train Loss: 0.30895 
24%|Train Loss: 0.30854 
28%|Train Loss: 0.30775 
33%|Train Loss: 0.30702 
38%|Train Loss: 0.30715 
42%|Train Loss: 0.30751 
47%|Train Loss: 0.30775 
52%|Train Loss: 0.30824 
57%|Train Loss: 0.30828 
61%|Train Loss: 0.30844 
66%|Train Loss: 0.30867 
71%|Train Loss: 0.30863 
75%|Train Loss: 0.30849 
80%|Train L

In [8]:
# Save weights: only the parameters (state_dict) are written, not the module object
model_file = "/content/drive/MyDrive/data_rakuten/textmodel.pth"
torch.save(obj=model.state_dict(), f=model_file)

In [9]:
# Load weights
model_file = "/content/drive/MyDrive/data_rakuten/textmodel.pth"
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved from a GPU session can still be loaded on a CPU-only runtime.
state_dict = torch.load(model_file, map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

# Generate csv file for submission

In [10]:
class TestDataset(torch.utils.data.Dataset):
    """Test set built from pre-computed text feature tensors (no labels).

    Two tensors are loaded from ``data_dir``: item-name features and
    item-caption features. Both are indexed as ``tensor[:, idx]``, i.e.
    samples are laid out along dimension 1.

    Args:
        data_dir: Directory containing ``Xtest_item_name.pt`` and
            ``Xtest_item_caption.pt``. Defaults to the Google Drive folder
            used by this notebook.
    """

    def __init__(self, data_dir='/content/drive/MyDrive/data_rakuten'):
        self.Xtest_item_name = torch.load(os.path.join(data_dir, 'Xtest_item_name.pt'))
        self.Xtest_item_caption = torch.load(os.path.join(data_dir, 'Xtest_item_caption.pt'))

    def __len__(self):
        """Return the number of samples (size of dimension 1 of the name tensor)."""
        return self.Xtest_item_name.shape[1]

    def __getitem__(self, idx):
        """Return the feature vector for sample ``idx``.

        The item-name column is followed by the item-caption column,
        concatenated along dimension 0.
        """
        return torch.cat(
            (self.Xtest_item_name[:, idx], self.Xtest_item_caption[:, idx]), 0
        )

testSet = TestDataset()
# One sample per batch, in dataset order (no shuffling).
testLoader = torch.utils.data.DataLoader(
    dataset=testSet,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)

In [11]:
# Output index -> colour tag; position i corresponds to the i-th model output.
inv_dico_labels={ 0: "Beige",1:"Black",2:"Blue",3:"Brown",4:"Burgundy",5:"Gold",6:"Green",7:"Grey",
                 8:"Khaki",9:"Multiple Colors",10:"Navy",11:"Orange",12:"Pink",
                 13:"Purple",14:"Red",15:"Silver",16:"Transparent",17:"White",18:"Yellow"}

# A tag is emitted when its raw logit is strictly greater than this value.
# NOTE: -1 on the logit scale is sigmoid(-1) ~= 0.27, not a probability of 0.5
# (a 0.5 cut-off would be a logit of 0). The value is kept unchanged so the
# predictions themselves do not change.
LOGIT_THRESHOLD = -1

model.eval()

# Write predictions to the submission.csv file.
# Each row holds two CSV fields: the sample index and the string form of the
# predicted tag list. The csv writer adds quotes only when a field needs them
# (i.e. when the list holds several tags and therefore contains a comma).
# newline='' lets the csv module control line endings itself.
with open('/content/drive/MyDrive/data_rakuten/submission.csv', 'w', newline='') as csvfile:
    spamwriter = csv.writer(csvfile, delimiter=',')
    # Header: unnamed index column followed by the tag column
    spamwriter.writerow(['', 'color_tags'])
    with torch.no_grad():
        # testLoader yields one sample per batch in dataset order, so the
        # batch index is also the sample index.
        for batch_idx, inputs in enumerate(testLoader):
            inputs = inputs.to(device)
            outputs = model(inputs)
            prediction = [
                inv_dico_labels[indice]
                for indice, logit in enumerate(outputs.squeeze(0).tolist())
                if logit > LOGIT_THRESHOLD
            ]
            spamwriter.writerow([batch_idx, str(prediction)])