In [1]:
import sys
sys.path.append('/Users/jodok/02 Code/EE-559-TEAM/Project1')
from src.dlc_practical_prologue import generate_pair_sets
from src.utils import load_class_data, load_target_data, load_all_data
from src.models import *
from src.trainer import Trainer
import matplotlib.pyplot as plt
from torch import nn

In [2]:
# Load data with class as target
dl_train_class, dl_val_class, dl_test_class = load_class_data()

# Load data with larger than as target
dl_train_target, dl_val_target, dl_test_target = load_target_data()

# Load data with class and larget than as targets
dl_train_all, dl_val_all, dl_test_all = load_all_data()

In [3]:
class TailNet(BaseModule):
    def __init__(self, lr=0.001):
        super().__init__(lr)
        self.conv1 = nn.Conv1d(1, 32, kernel_size=10, stride=10)
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 2)
        self.flat = nn.Flatten(start_dim=1)

    def forward(self, x):
        x = nn.functional.one_hot(x).float()
        x = x.view(-1, 1, 20)
        x = self.conv1(x)
        x = self.flat(x)
        x = nn.functional.relu(x)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        _, x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        return loss
    
    def validation_step(self, batch, batch_idx):
        _, x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        preds = torch.argmax(out, dim=1)
        acc = self.accuracy(preds, y)
        return loss, acc

class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 1, kernel_size=10, stride=10)
    
    def forward(self, x):
        x = nn.functional.one_hot(x).float()
        x = x.view(-1, 1, 20)
        x = self.conv1(x)
        return x

class TailFullyConv(BaseModule):
    def __init__(self, lr=0.001):
        super().__init__(lr)
        self.conv1 = nn.Conv1d(1, 32, kernel_size=10, stride=10)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=1, stride=1)
        self.conv3 = nn.Conv1d(64, 2, kernel_size=1, stride=1)
        self.flat = nn.Flatten(start_dim=1)

    def forward(self, x):
        x = nn.functional.one_hot(x).float()
        x = x.view(-1, 1, 20)
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.conv3(x)
        x = self.flat(x)
        return x

    def training_step(self, batch, batch_idx):
        _, x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        return loss
    
    def validation_step(self, batch, batch_idx):
        _, x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        preds = torch.argmax(out, dim=1)
        acc = self.accuracy(preds, y)
        return loss, acc

class TailLinear(BaseModule):
    def __init__(self, lr=0.001):
        super().__init__(lr)
        self.fc1 = nn.Linear(20, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 2)
        self.flat = nn.Flatten(start_dim=1)

    def forward(self, x):
        #x = nn.functional.one_hot(x).float()
        #x = self.flat(x)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        x = nn.functional.relu(x)
        x = self.fc3(x)
        return x

    def training_step(self, batch, batch_idx):
        _, x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        return loss
    
    def validation_step(self, batch, batch_idx):
        _, x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        preds = torch.argmax(out, dim=1)
        acc = self.accuracy(preds, y)
        return loss, acc

In [39]:
tail = TailLinear()
#tail = TailFullyConv()
trainer = Trainer(nb_epochs=5)
trainer.fit(tail, dl_train_all, dl_val_all)

# Epoch 1/5:	 loss=0.66	 loss_val=0.62	 acc_val=69.64
# Epoch 2/5:	 loss=0.53	 loss_val=0.39	 acc_val=94.64
# Epoch 3/5:	 loss=0.28	 loss_val=0.15	 acc_val=99.55
# Epoch 4/5:	 loss=0.11	 loss_val=0.07	 acc_val=100.0
# Epoch 5/5:	 loss=0.05	 loss_val=0.03	 acc_val=100.0


In [8]:
lenet = LeNet()
target_net = TailLinear()
siamese = Siamese(auxiliary=lenet, target=target_net, weight_aux=0.5)
trainer = Trainer(nb_epochs=25)
trainer.fit(siamese, dl_train_all, dl_val_all)

# Epoch 1/15:	 loss=2.53	 loss_val=0.68	 acc_val=62.5
# Epoch 2/15:	 loss=1.41	 loss_val=0.62	 acc_val=72.32
# Epoch 3/15:	 loss=0.98	 loss_val=0.53	 acc_val=72.32
# Epoch 4/15:	 loss=0.7	 loss_val=0.51	 acc_val=74.55
# Epoch 5/15:	 loss=0.63	 loss_val=0.47	 acc_val=78.57
# Epoch 6/15:	 loss=0.64	 loss_val=0.49	 acc_val=77.23
# Epoch 7/15:	 loss=0.49	 loss_val=0.52	 acc_val=75.45
# Epoch 8/15:	 loss=0.43	 loss_val=0.44	 acc_val=82.14
# Epoch 9/15:	 loss=0.37	 loss_val=0.44	 acc_val=81.25
# Epoch 10/15:	 loss=0.34	 loss_val=0.42	 acc_val=83.04
# Epoch 11/15:	 loss=0.33	 loss_val=0.41	 acc_val=82.14
# Epoch 12/15:	 loss=0.3	 loss_val=0.42	 acc_val=83.04
# Epoch 13/15:	 loss=0.27	 loss_val=0.39	 acc_val=86.16
# Epoch 14/15:	 loss=0.24	 loss_val=0.39	 acc_val=87.5
# Epoch 15/15:	 loss=0.22	 loss_val=0.32	 acc_val=90.62


In [9]:
trainer.fit(siamese, dl_train_all, dl_val_all)

# Epoch 1/15:	 loss=0.22	 loss_val=0.31	 acc_val=93.3
# Epoch 2/15:	 loss=0.19	 loss_val=0.3	 acc_val=92.41
# Epoch 3/15:	 loss=0.16	 loss_val=0.26	 acc_val=92.86
# Epoch 4/15:	 loss=0.22	 loss_val=0.3	 acc_val=90.18
# Epoch 5/15:	 loss=0.19	 loss_val=0.34	 acc_val=88.39
# Epoch 6/15:	 loss=0.11	 loss_val=0.28	 acc_val=91.96
# Epoch 7/15:	 loss=0.1	 loss_val=0.29	 acc_val=92.86
# Epoch 8/15:	 loss=0.08	 loss_val=0.28	 acc_val=93.3
# Epoch 9/15:	 loss=0.06	 loss_val=0.3	 acc_val=93.75
# Epoch 10/15:	 loss=0.05	 loss_val=0.3	 acc_val=93.75
# Epoch 11/15:	 loss=0.05	 loss_val=0.31	 acc_val=93.3
# Epoch 12/15:	 loss=0.06	 loss_val=0.43	 acc_val=90.18
# Epoch 13/15:	 loss=0.11	 loss_val=0.33	 acc_val=93.75
# Epoch 14/15:	 loss=0.1	 loss_val=0.27	 acc_val=93.75
# Epoch 15/15:	 loss=0.09	 loss_val=0.31	 acc_val=93.75


In [36]:
for batch in dl_train_all:
    _, x, y = batch
    break

In [28]:
test_net = TestNet()
out = test_net(x)

In [35]:
out.view(-1, 2).size()

torch.Size([32, 2])

In [20]:
torch.argmax(out, dim=2)
y

tensor([0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0, 1, 0])

In [37]:
x = nn.functional.one_hot(x, num_classes=10)
x.size()

torch.Size([32, 2, 10])

In [13]:
x.view(-1, 1, 20).float().size()

torch.Size([32, 1, 20])

In [23]:
nn.Flatten()(x).size()

torch.Size([32, 20])