In [231]:
import glob
import json
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import pytorch_lightning as pl
import torch
import torchvision.models as models
import tqdm
from torchvision import transforms

In [13]:
# Preprocessing applied to every image before it reaches the model:
# ToTensor turns the H x W x C image array into a C x H x W float tensor, then
# Resize brings it to 28 x 28. No normalization step is applied.
# NOTE(review): presumably this mirrors the training-time preprocessing of the
# checkpoint — confirm before changing it.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(28, 28)),
])

In [57]:
# Class names in the order of the classifier's output units: output index j of
# the model is read as train_dataset_classes[j]. The names are in plain
# lexicographic string order (e.g. "1.10" before "1.2").
# NOTE(review): presumably this is the sorted folder order produced when the
# training set was loaded — confirm before editing; reordering breaks the model.
train_dataset_classes = [
    "1.10", "1.16", "1.17", "1.2", "1.20.1", "1.20.2", "1.22", "1.25", "1.34.3",
    "2.1", "2.3.1", "2.4",
    "3.1", "3.11", "3.12", "3.13", "3.2", "3.20", "3.24", "3.25", "3.27", "3.28", "3.4",
    "4.1.1", "4.1.4", "4.2.1", "4.2.3",
    "5.15.1", "5.15.3", "5.15.5", "5.15.6", "5.16", "5.18", "5.19.1", "5.19.2",
    "5.21", "5.23.1", "5.5", "5.6", "5.9",
    "6.10.1", "6.11", "6.12", "6.13", "6.16",
    "7.12", "7.3",
    "8.2.1", "8.2.2", "8.2.3", "8.2.4", "8.22.1", "8.22.2", "8.22.3", "8.23",
]

In [58]:
# The classifier head gets one output per name. A duplicate name would make two
# outputs map to the same sign, so fail fast instead of silently mis-mapping.
assert len(set(train_dataset_classes)) == len(train_dataset_classes), "duplicate class names"
len(train_dataset_classes)

55

In [59]:
class ImagenetTransferLearning(pl.LightningModule):
    """ResNet-50 (pretrained) feature extractor with a trainable linear head.

    The head has one output per entry of ``train_dataset_classes``. The backbone
    is run in eval mode without gradients in ``forward``, so only
    ``self.classifier`` learns.
    """

    def __init__(self):
        super().__init__()

        # NOTE: `pretrained=True` is the legacy torchvision argument (newer
        # releases use `weights=`); kept for compatibility with the installed version.
        backbone = models.resnet50(pretrained=True)
        num_filters = backbone.fc.in_features
        # Every child module except the final `fc` classification layer.
        layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*layers)

        num_target_classes = len(train_dataset_classes)
        self.classifier = torch.nn.Linear(num_filters, num_target_classes)

        # The loss is owned by the module instead of relying on a notebook-global
        # `loss` (not defined in this notebook), so the class works on a fresh
        # kernel. CrossEntropyLoss has no parameters/buffers, so checkpoint
        # state_dicts are unaffected.
        # NOTE(review): cross-entropy is assumed to be what the checkpoint was
        # trained with — confirm.
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch ``x``."""
        # Re-applied on every call so the frozen backbone never switches to
        # train mode (which would change its BatchNorm behaviour).
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        return self.classifier(representations)

    def _shared_step(self, batch, stage):
        """Compute loss and batch accuracy for ``batch`` and log them under ``stage``.

        Assumes ``batch`` is ``(images, labels)`` with integer class-index
        labels — TODO confirm against the training DataLoader.
        """
        x, y = batch
        pred = self(x)
        loss_value = self.loss_fn(pred, y)
        # Stateless per-batch accuracy; Lightning averages it over the epoch
        # (on_epoch=True). Train and validation no longer share (and never
        # reset) a single accumulating metric object.
        batch_acc = (pred.argmax(dim=1) == y).float().mean()
        self.log(f'loss/{stage}', loss_value, on_epoch=True)
        self.log(f'accuracy/{stage}', batch_acc, on_epoch=True)
        return loss_value

    def training_step(self, batch, batch_idx):
        """Training loss for one batch; logs ``loss/train`` and ``accuracy/train``."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Validation loss for one batch; logs ``loss/valid`` and ``accuracy/valid``."""
        return self._shared_step(batch, 'valid')

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau driven by epoch-level validation accuracy."""
        # All parameters are passed, but backbone weights receive no gradients
        # (forward runs them under no_grad), so Adam skips them. The list is kept
        # unchanged so optimizer state in existing checkpoints still matches.
        optimizer = torch.optim.Adam(self.parameters(), lr=1.0e-3)
        lr_scheduler = {
            # mode='max': the monitored quantity is an accuracy (higher is better).
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.3, verbose=True),
            "interval": "epoch",
            "frequency": 1,
            "monitor": "accuracy/valid",
            "strict": True,
            "name": None,
        }
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

In [64]:
# Submission label vocabulary. The 1-based position of a code in this list is
# the numeric id written into the solution file. It contains codes that the
# classifier never predicts and, conversely, does not contain every training
# class (e.g. "5.9").
signs = [
    "3.24", "1.16", "5.15.5", "5.19.1", "5.19.2", "1.20.1", "8.23",
    "2.1", "4.2.1", "8.22.1", "6.16", "1.22", "1.2", "5.16", "3.27",
    "6.10.1", "8.2.4", "6.12", "5.15.2", "3.13", "3.1", "3.20", "3.12",
    "7.14.2", "5.23.1", "2.4", "5.6", "4.2.3", "8.22.3", "5.15.1",
    "7.3", "3", "2.3.1", "3.11", "6.13", "5.15.4", "8.2.1", "1.34.3",
    "8.2.2", "5.15.3", "1.17", "4.1.1", "4.1.4", "3.25", "1.20.2",
    "8.22.2", "6.9.2", "3.2", "5.5", "5.15.7", "7.12", "8.2.3",
    "5.24.1", "1.25", "3.28", "5.9.1", "5.15.6", "8.1.1", "1.10",
    "6.11", "3.4", "6.10", "6.9.1", "8.2.5", "5.15", "4.8.2", "8.22",
    "5.21", "5.18",
]

In [159]:
# Restore the trained 55-class model for inference.
# - A plain relative path replaces the Windows-only r".\..." form; it resolves
#   to the same file in the working directory on every OS.
# - map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine;
#   the inference code in this notebook feeds CPU tensors.
# - .eval() returns the module itself, so the cell still shows no output.
model = ImagenetTransferLearning.load_from_checkpoint(
    "model_55_classes.ckpt", map_location="cpu"
).eval()



In [238]:
# Split to predict on. It selects both the "<data_type>.csv" index file and the
# "<data_type>/<id>/*.jpg" image folders.
data_type = "test"  # e.g. switch to another split name that has the same layout

In [240]:
# One row per sample. The prediction step drops the "img" column and indexes by
# "id", so check both exist up front for a clear error instead of a KeyError later.
data = pd.read_csv(f"{data_type}.csv")
assert {"id", "img"} <= set(data.columns), f"{data_type}.csv must contain 'id' and 'img' columns, got {list(data.columns)}"

In [247]:
# Build the submission: one row per id, with up to MAX_SIGNS_PER_ID sign-id
# columns (0 = empty slot), filled in ascending sign-id order.
# NOTE(review): the column prefix is "sing", not "sign". It is kept unchanged
# because the expected submission header is not visible here — confirm.
MAX_SIGNS_PER_ID = 8
# Confidence gate: an image's prediction is kept only if the largest gap
# between two adjacent *sorted* logits exceeds this value.
# NOTE(review): that gap can lie anywhere in the sorted logits, not necessarily
# between the top-1 and top-2 scores. If a top-1 margin was intended, compare
# sorted_logits[-1] - sorted_logits[-2] instead.
MIN_LOGIT_GAP = 12

# Classifier output index -> 1-based position of that class in `signs` (the id
# written to the solution). Classes missing from `signs` used to be dropped
# silently, per image, by a bare `except`; they are now reported once.
class_to_sign_id = {
    class_idx: signs.index(class_name) + 1
    for class_idx, class_name in enumerate(train_dataset_classes)
    if class_name in signs
}
unmapped_classes = [name for name in train_dataset_classes if name not in signs]
if unmapped_classes:
    warnings.warn(f"Predictions for these classes are ignored (not in `signs`): {unmapped_classes}")


def predict_sign_id(img_path):
    """Return the sign id predicted for one image file.

    Returns None when the prediction fails the MIN_LOGIT_GAP confidence gate or
    when the predicted class has no entry in `signs`. Expects a 3-channel image;
    the batch dimension is added here.
    """
    image = transform(plt.imread(img_path)).unsqueeze(0)
    logits = model(image)[0]
    sorted_logits = logits.sort().values
    gap = (sorted_logits[1:] - sorted_logits[:-1]).max().item()
    # Written as `gap > threshold` (not `<=` -> return) so a NaN gap is rejected.
    if gap > MIN_LOGIT_GAP:
        return class_to_sign_id.get(logits.argmax().item())
    return None


solution = (
    data.drop(columns="img")
    .assign(**{f"sing{i}": 0 for i in range(1, MAX_SIGNS_PER_ID + 1)})
    .set_index("id")
)

# Inference only: don't build autograd graphs for the classifier head.
with torch.no_grad():
    # Iterate the id column directly; iterrows() can upcast ids (e.g. int -> float),
    # which would break the folder path below.
    for sample_id in tqdm.tqdm(data["id"]):
        predicted = {predict_sign_id(p) for p in glob.glob(f"{data_type}/{sample_id}/*.jpg")}
        predicted.discard(None)
        for slot, sign_id in enumerate(sorted(predicted)[:MAX_SIGNS_PER_ID], start=1):
            solution.loc[sample_id, f"sing{slot}"] = sign_id

solution = solution.reset_index()

100%|████████████████████████████████████████████████████████████████████████████████| 388/388 [02:46<00:00,  2.33it/s]


In [248]:
# Write the submission file; the pandas row index is omitted because "id" is
# already an ordinary column after reset_index().
solution.to_csv(path_or_buf="solution.csv", index=False)