<a href="https://colab.research.google.com/github/perrin-isir/tp_classif_images/blob/main/tp_classif_images.ipynb"> <img align="left" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab" title="Open in Google Colaboratory"></a>
<a id="raw-url" href="https://raw.githubusercontent.com/perrin-isir/tp_classif_images/main/tp_classif_images.ipynb" download> <img align="left" src="https://img.shields.io/badge/Github-Download%20(Right%20click%20%2B%20Save%20link%20as...)-blue" alt="Download (Right click + Save link as)" title="Download Notebook"></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
# from torch.autograd import Variable
import torchvision.transforms as transforms
device = torch.device("cpu")
!pip install deeplake[enterprise]
import deeplake
ds = deeplake.load('hub://activeloop/cacd')

In [None]:
import os
tmp_dir = os.path.join(os.path.expanduser("~"), "tmp_data")
!wget -P {tmp_dir} "http://www.umiacs.umd.edu/~sirius/CACD/celebrity2000_meta.mat"
import scipy.io
mat = scipy.io.loadmat(os.path.join(os.path.expanduser("~"), "tmp_data", "celebrity2000_meta.mat"))
estimated_ages = mat["celebrityImageData"][0, 0][0].flatten()
age_threshold = 35

In [None]:
class Net(nn.Module):
    """Small convolutional classifier for square RGB images.

    Two 3x3 convolutions, each followed by ReLU, 2x2 max-pooling and channel-wise
    dropout, feed three fully connected layers that output one logit per class.

    Args:
        image_size: Height and width (in pixels) of the square input images.
        num_classes: Number of logits produced by the last layer.
    """

    def __init__(self, image_size: int = 75, num_classes: int = 2):
        super().__init__()
        self.image_size = image_size
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.25)
        # dropout3 acts on the flattened (batch, features) activations, so it must be
        # the element-wise Dropout: Dropout2d is meant for 3-D/4-D feature maps.
        self.dropout3 = nn.Dropout(0.5)
        # Spatial side after conv(3x3) -> pool(2) -> conv(3x3) -> pool(2).
        side = ((image_size - 2) // 2 - 2) // 2
        self.fc1 = nn.Linear(side * side * 64, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the class logits, shape (batch, num_classes).

        `x` may be a single image or a batch; it is reshaped to
        (-1, 3, image_size, image_size) before the first convolution.
        """
        x = x.reshape(-1, 3, self.image_size, self.image_size)
        x = F.relu(self.conv1(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = F.relu(self.conv2(x))
        x = self.dropout2(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = self.dropout3(F.relu(self.fc1(x)))
        # No non-linearity between fc2 and fc3: together they form one affine map.
        x = self.fc2(x)
        x = self.fc3(x)
        return x


# Instantiate the classifier and move its parameters to `device`.
net = Net().to(device)

In [None]:
def _to_device(tensor):
    """Move a tensor onto the globally selected `device`."""
    return tensor.to(device)


# Augmenting pipeline: tensor conversion, resize, random crop to 75x75, random
# horizontal flip, then per-channel normalization.
# NOTE(review): the mean/std constants look like the commonly quoted CIFAR-10
# statistics -- confirm they are the ones intended for this dataset.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(_to_device),
        transforms.Resize(100),
        transforms.RandomCrop(75, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# Minimal pipeline: tensor conversion and device placement only (no resize, no
# augmentation, no normalization).
transform_no_modif = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(_to_device),
    ]
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Plotting tool

def _clear_image_axis(axis):
    """Remove ticks, tick labels and grid from an axis that will display an image."""
    axis.tick_params(axis=u'both', which=u'both', length=0)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.grid(False)
    axis.set_xticklabels([])
    axis.set_yticklabels([])


def _draw_probability_bars(axis, probs):
    """Draw class probabilities as horizontal bars, each labelled with its class index."""
    axis.tick_params(axis=u'both', which=u'both', length=0)
    axis.set_ylim([-0.5, 10.5])
    axis.set_xlim([-1.5, 12.5])
    axis.set_xticks([])
    axis.set_xticks(np.arange(0.5, 9.5, 1), minor=True)
    axis.set_yticks([])
    axis.grid(False)
    axis.set_aspect('equal')
    colors = [(0.9, 0.1, 0.0, 1.0), (0.0, 0.1, 0.9, 1.0)]
    # Bars are drawn bottom-up, hence the reversed order: class 0 ends up on top.
    axis.barh([1, 5], [z * 10.0 for z in reversed(probs)], color=colors)
    for idx in range(len(probs)):
        # Label only the bars that are long enough to be visible.
        if probs[idx] > 0.02:
            axis.text(10.0 * probs[idx] + 0.15, (len(probs) - 1 - idx + 0.1) * 4.0, idx)


def plotdata(data, indexes, model=None, original=True):
    """Show a grid of dataset images, optionally with the model's class probabilities.

    Args:
        data: Dataset whose `images[i].numpy()` returns the i-th image.
        indexes: Iterable of sample indices to display; nothing is drawn if empty.
        model: Optional classifier. When given, a bar chart of its softmax output is
            drawn under every image. The prediction is always computed on the
            `transform`-ed (augmented) version of the image, passed as a single
            unbatched tensor.
        original: If True, display the image produced by `transform_no_modif`;
            otherwise display the augmented image that is fed to the model.
    """
    augmented = []
    untouched = []
    for elt in indexes:
        raw = data.images[elt].numpy()  # fetch each sample once, reuse for both pipelines
        augmented.append(transform(raw))
        untouched.append(transform_no_modif(raw))
    k = len(augmented)
    if k == 0:
        # Nothing to show; the grid-size computation below would divide by zero.
        return

    probabilities = None
    if model is not None:
        # Predict in eval mode and without autograd so that dropout does not randomize
        # the displayed probabilities; the previous mode is restored afterwards.
        is_module = isinstance(model, nn.Module)
        was_training = model.training if is_module else False
        if is_module:
            model.eval()
        try:
            with torch.no_grad():
                probabilities = [
                    F.softmax(model(img), dim=1).cpu().numpy().flatten()
                    for img in augmented
                ]
        finally:
            if is_module:
                model.train(was_training)

    # Choose n columns and m rows so that n / m is roughly 30 / 9.
    cte = 30.0 / 9.0
    n = int(np.sqrt(cte * k))
    m = int(k / (n * 1.0))
    if n * m < k:
        m = m + 1
    width = 20
    # With a model, every image row is followed by a row of probability charts.
    mult = 1 if model is None else 2
    fig, ax = plt.subplots(
        mult * m, n, squeeze=False, figsize=(width, int(width * m / (n * 2.0)))
    )
    shown = untouched if original else augmented
    for i in range(m):
        for j in range(n):
            pos = j + n * i
            image_axis = ax[mult * i, j]
            if pos >= k:
                # Unused grid cell: hide it (and its chart cell, if any).
                image_axis.axis('off')
                if model is not None:
                    ax[mult * i + 1, j].axis('off')
                continue
            if model is not None:
                _draw_probability_bars(ax[mult * i + 1, j], probabilities[pos])
            _clear_image_axis(image_axis)
            img = shown[pos].cpu()  # matplotlib needs host memory
            if img.shape[0] == 3:
                # Channels-first tensor -> channels-last image.
                image_axis.imshow(img.permute(1, 2, 0))
            else:
                image_axis.matshow(img[0, :, :], cmap='Greys')
    plt.show()

In [None]:
# Sanity check: forward one augmented image of the dataset through the network.
net(transform(ds.images[0].numpy()))

In [None]:
# Random partition of the dataset: 20000 training samples, 2000 test samples, and
# the remainder, which is collected in `trash_ds`.
n_train, n_test = 20000, 2000
split_sizes = [n_train, n_test, len(ds) - n_train - n_test]
train_ds, test_ds, trash_ds = ds.random_split(split_sizes)

In [None]:
# Training loader: shuffled batches of 40 samples. `transform` is applied to the
# "images" entry; images are decoded as PIL before the transform runs.
train_dataloader = (
    train_ds.dataloader()
    .transform({'images': transform, 'keypoints': None, "index": None})
    .batch(40)
    .shuffle()
    .pytorch(decode_method={'images': 'pil'})
)

# Test loader: same pipeline (including the random augmentations of `transform`),
# with batches of 100 samples.
test_dataloader = (
    test_ds.dataloader()
    .transform({'images': transform, 'keypoints': None, "index": None})
    .batch(100)
    .shuffle()
    .pytorch(decode_method={'images': 'pil'})
)

In [None]:
# evaluation on a batch of test data:
def evaluate(model, dataloader):
    """Measure the accuracy of `model` on the first batch yielded by `dataloader`.

    The target of a sample is 1 when its entry in the global `estimated_ages` array
    is strictly greater than `age_threshold`, and 0 otherwise.

    Args:
        model: Network mapping a batch of images to per-class logits.
        dataloader: Iterable of dict batches with "images" and "index" entries.

    Returns:
        The accuracy on that batch as a percentage (Python float). The same figure
        is also printed.
    """
    batch = next(iter(dataloader))
    indices = batch["index"].flatten()
    images = batch["images"]
    targets = torch.as_tensor(estimated_ages[indices] > age_threshold).long().to(device)
    # Evaluate without dropout and without building an autograd graph, then put the
    # model back in whatever mode it was in, even if the forward pass fails.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outp = torch.argmax(model(images), dim=1)
    finally:
        model.train(was_training)
    t = torch.sum(outp == targets)
    result = t * 100.0 / len(indices)
    print(f"{t} correct on {len(indices)} ({result.item()} %)")
    return result.item()

In [None]:
# Accuracy on one test batch, before any training step has been run in this notebook.
evaluate(net, test_dataloader)

In [None]:
# iteratively train on 50 batches:
def train_epoch(model, optimizer, dataloader, num_batches=50):
    """Run one optimizer step on each of the first `num_batches` batches of `dataloader`.

    The target of a sample is 1 when its entry in the global `estimated_ages` array
    is strictly greater than `age_threshold`, and 0 otherwise; the loss is the
    cross-entropy between the model's logits and these targets.

    Args:
        model: Network mapping a batch of images to per-class logits.
        optimizer: Optimizer built over `model`'s parameters.
        dataloader: Iterable of dict batches with "images" and "index" entries.
        num_batches: Maximum number of batches to train on. Training simply stops
            early if the dataloader yields fewer batches.
    """
    model.train()  # make sure dropout is active, whatever mode the caller left
    # range() comes first in zip so that no extra batch is pulled from the loader
    # once `num_batches` steps have been done.
    for i_count, batch in zip(range(num_batches), dataloader):
        indices = batch["index"].flatten()
        images = batch["images"]
        targets = torch.as_tensor(estimated_ages[indices] > age_threshold).long().to(device)
        loss = F.cross_entropy(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not i_count % 10:
            print(f"    step {i_count}")

In [None]:
# Hyper-parameters of the stochastic gradient descent with momentum used for training.
learning_rate, momentum = 0.01, 0.5
optimizer = optim.SGD(
    net.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [None]:
# One pass of train_epoch over the training loader, as a first check of the loop.
train_epoch(net, optimizer, train_dataloader)

In [None]:
num_epochs = 1
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('./data', exist_ok=True)
for j in range(num_epochs):
    print(f"epoch {j} / {num_epochs}")
    train_epoch(net, optimizer, train_dataloader)
    evaluate(net, test_dataloader)
    # Checkpoint the weights after every epoch.
    torch.save(net.state_dict(), './data/model_TP.pt')

In [None]:
# Draw 4 random positions (with replacement) within the test split.
indices = np.random.choice(range(len(test_ds)), 4)

In [None]:
# Show the selected test images as fed to the network, with its predicted probabilities.
plotdata(test_ds, indices, net, original=False)