In [None]:
import copy
import os

import torch
import numpy as np
from torchvision import transforms, utils
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from util.utils import *
from data.dataloader import *
from model.Pred_net import Pred_Net
from model.Ind_net import Ind_Net
from model.Multi_net import Multi_net

In [None]:
# Pick the compute device once; every model and batch below is moved onto it.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")

In [None]:
# train the Ind_net
model = Ind_Net().to(device)
epochs = 80
lr = 1e-5
weight_decay=5e-4
early_stop = 10
batch_size = 16
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=lr,weight_decay=weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer,5,gamma=0.7)
train_data = Ind_dataset(dir_path='data/patches/train',transform=transforms.ToTensor())
val_data = Ind_dataset(dir_path='data/patches/validate',transform=transforms.ToTensor())
test_data = Ind_dataset(dir_path='data/patches/test',transform=transforms.ToTensor())
train_dataloader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
val_dataloader = DataLoader(val_data,batch_size=batch_size,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=batch_size,shuffle=True)

In [None]:
# Train Ind_Net with early stopping on validation accuracy.
# A full-model checkpoint is written every time validation accuracy improves.
# After the loop, `model` is rolled back to the best-scoring weights, so later cells
# see the best model rather than the one from the last (non-improving) epoch.
os.makedirs('checkpoint/Ind', exist_ok=True)  # torch.save does not create missing directories
best_acc = 0
best_state = copy.deepcopy(model.state_dict())
es = 0  # consecutive epochs without a validation improvement
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, device)
    acc = evaluate(val_dataloader, model, loss_fn, device, 'validate')
    scheduler.step()
    if acc > best_acc:
        best_acc = acc
        best_state = copy.deepcopy(model.state_dict())
        es = 0
        torch.save(model, f'checkpoint/Ind/Ind_Net_{(100*acc):>0.1f}%.pth')
    else:
        es += 1
    if es == early_stop:
        print("Early stopping with best_acc: ", best_acc)
        break
model.load_state_dict(best_state);

In [None]:
# train the Pred_net
# Configure Pred_Net training: model, hyperparameters, optimiser, LR schedule and data loaders.
model = Pred_Net().to(device)
epochs = 80
lr = 1e-5
weight_decay = 5e-4
early_stop = 10  # stop after this many consecutive epochs without a validation improvement
batch_size = 16
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Multiplies lr by 0.7 every 5 scheduler steps (the training loop steps once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.7)
train_data = Pred_dataset(dir_path='data/patches/train', transform=transforms.ToTensor())
val_data = Pred_dataset(dir_path='data/patches/validate', transform=transforms.ToTensor())
test_data = Pred_dataset(dir_path='data/patches/test', transform=transforms.ToTensor())
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Only training benefits from shuffling. Evaluation loaders keep a fixed order, which
# (assuming evaluate() aggregates over the full set) changes no metric but removes nondeterminism.
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Train Pred_Net with early stopping on validation accuracy.
# A full-model checkpoint is written every time validation accuracy improves.
# After the loop, `model` is rolled back to the best-scoring weights, so later cells
# see the best model rather than the one from the last (non-improving) epoch.
os.makedirs('checkpoint/Pred', exist_ok=True)  # torch.save does not create missing directories
best_acc = 0
best_state = copy.deepcopy(model.state_dict())
es = 0  # consecutive epochs without a validation improvement
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, device)
    acc = evaluate(val_dataloader, model, loss_fn, device, 'validate')
    scheduler.step()
    if acc > best_acc:
        best_acc = acc
        best_state = copy.deepcopy(model.state_dict())
        es = 0
        torch.save(model, f'checkpoint/Pred/Pred_Net_{(100*acc):>0.1f}%.pth')
    else:
        es += 1
    if es == early_stop:
        print("Early stopping with best_acc: ", best_acc)
        break
model.load_state_dict(best_state);

In [None]:
# train the Multi_net
# Configure Multi_net training on top of previously trained Ind_Net / Pred_Net checkpoints.
# The checkpoint filenames embed the validation accuracy of the run that produced them;
# update these two constants after retraining either sub-network.
# Forward slashes are used because backslashes would be parsed as (invalid) string escape
# sequences and would only resolve as path separators on Windows.
IND_CHECKPOINT = 'checkpoint/Ind/Ind_Net_84.9%.pth'
PRED_CHECKPOINT = 'checkpoint/Pred/Pred_Net_85.4%.pth'
# The checkpoints are whole pickled modules (written with torch.save(model, ...)), so
# weights_only=False is required on torch>=2.6. Only load checkpoint files you trust.
# map_location lets GPU-saved checkpoints load on a CPU-only machine.
Ind = torch.load(IND_CHECKPOINT, map_location=device, weights_only=False).to(device)
Pred = torch.load(PRED_CHECKPOINT, map_location=device, weights_only=False).to(device)
model = Multi_net(Ind, Pred).to(device)
epochs = 80
lr = 1e-4
weight_decay = 5e-4
early_stop = 10  # stop after this many consecutive epochs without a validation improvement
batch_size = 16
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Multiplies lr by 0.7 every 5 scheduler steps (the training loop steps once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.7)
train_data = Multi_dataset(dir_path='data/patches/train', transform=transforms.ToTensor())
val_data = Multi_dataset(dir_path='data/patches/validate', transform=transforms.ToTensor())
test_data = Multi_dataset(dir_path='data/patches/test', transform=transforms.ToTensor())
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Only training benefits from shuffling. Evaluation loaders keep a fixed order.
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Train Multi_net with early stopping on validation accuracy.
# A full-model checkpoint is written every time validation accuracy improves.
# After the loop, `model` is rolled back to the best-scoring weights, so the test
# evaluation that follows measures the best model, not the last epoch's.
os.makedirs('checkpoint/Multi', exist_ok=True)  # torch.save does not create missing directories
best_acc = 0
best_state = copy.deepcopy(model.state_dict())
es = 0  # consecutive epochs without a validation improvement
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_multi(train_dataloader, model, loss_fn, optimizer, device)
    acc = evaluate_multi(val_dataloader, model, loss_fn, device, 'validate')
    scheduler.step()
    if acc > best_acc:
        best_acc = acc
        best_state = copy.deepcopy(model.state_dict())
        es = 0
        torch.save(model, f'checkpoint/Multi/Multi_Net_{(100*acc):>0.1f}%.pth')
    else:
        es += 1
    if es == early_stop:
        print("Early stopping with best_acc: ", best_acc)
        break
model.load_state_dict(best_state)
print("Done!")

In [None]:
# Test-set evaluation of the Multi_net. Use evaluate_multi, the same helper the training
# loop uses to validate this model on Multi_dataset batches. evaluate() is the helper paired
# with the single Ind_Net / Pred_Net datasets.
evaluate_multi(test_dataloader, model, loss_fn, device, 'test')