In [1]:
from make_dataloader import get_dataloader
from training_test_loops import training_loop, test_loop
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchsummary import summary
import time
torch.manual_seed(222222)
np.random.seed(2222)

In [2]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=(1,1), padding='same')
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=(1,1), padding='same')
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=(1,1), padding='same')
        self.bn3 = nn.BatchNorm2d(32)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(0.)
        
        self.fc1 = nn.Linear(in_features=32*8*8, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=4)
        
    def forward(self, x):
        x = x.view(-1, 3, 64, 64)
        
        x = self.dropout(self.pool(self.bn1(F.relu(self.conv1(x)))))
        x = self.dropout(self.pool(self.bn2(F.relu(self.conv2(x)))))
        x = self.dropout(self.pool(self.bn3(F.relu(self.conv3(x)))))
        #x = self.pool(self.bn4(F.relu(self.conv4(x))))
        
        x = x.view(-1, 32*8*8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        x = x.view(-1, 4)
        return x

model = NeuralNetwork().cuda()

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)
#summary(model, (3, 128, 128));

276932


In [3]:
root_dir = 'C:/Users/paxso/galclass_da/gal_img_full/gal_img_full/'
csv_file = root_dir+"gal_list.csv"
domain = 'source'
batch_size = 32
data = get_dataloader(csv_file, root_dir, domain, batch_size, train_size=.90, val_size=.05, test_size=.05)
train_dataset, val_dataset, test_dataset, train_dataloader, val_dataloader, test_dataloader = data

In [4]:
epochs = 120
lr = 5.0e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss().cuda()

In [5]:
for epoch in range(epochs):
    print('epoch: ', epoch)
    start_time = time.time()
    training_accuracy, training_loss, val_accuracy, val_loss = training_loop(train_dataloader, 
                                                                         val_dataloader, model, optimizer, loss_fn)
    print(training_accuracy, training_loss, val_accuracy, val_loss)
    print(time.time() - start_time)

epoch:  0
0.6811832006981063 1.0560866132414142 0.7179108485499462 1.020265246921343
82.33822822570801
epoch:  1
0.7272984926572045 1.0116679248917928 0.7369763694951665 1.0022008355083383
81.42657589912415
epoch:  2
0.7392021002856568 0.9995690956889587 0.7373791621911923 1.0008628672796258
81.79783487319946
epoch:  3
0.7467872938684488 0.9923468488782005 0.749328678839957 0.9904257872585575
81.85513496398926
epoch:  4
0.7529703081065358 0.9866427429678332 0.7502685284640171 0.9899484126352956
82.50599360466003
epoch:  5
0.7551332443297508 0.9843429432137201 0.7518796992481203 0.9886822350035409
82.26624536514282
epoch:  6
0.7586088590884342 0.9815299282779557 0.7592642320085929 0.9801751633570429
82.2315583229065
epoch:  7
0.7620993906486571 0.9775969755535763 0.7671858216970999 0.9721794980278342
84.13679933547974
epoch:  8
0.7621888914578936 0.9778848774677814 0.7624865735767992 0.9782766976069994
82.58145093917847
epoch:  9
0.764978333345764 0.9751727931539312 0.7607411385606875 0

0.8031653452866636 0.9386650130401648 0.7837003222341569 0.9573108124119017
89.69862008094788
epoch:  80
0.8042617301998105 0.9375530812364774 0.7914876476906552 0.9486500556888499
89.47496247291565
epoch:  81
0.7995331041118163 0.9423399378832314 0.7996777658431794 0.941840801116223
88.559490442276
epoch:  82
0.7998538153449137 0.9421088826115774 0.7921589688506981 0.9496103097952486
85.06131982803345
epoch:  83
0.7963185333800726 0.9455479576650269 0.7941729323308271 0.9486249345054953
83.61223220825195
epoch:  84
0.8016736651327222 0.9403153569857658 0.7909505907626209 0.950305447813779
86.66088914871216
epoch:  85
0.8039484773674829 0.9381249597931819 0.7926960257787325 0.9496435420707572
87.33567595481873
epoch:  86
0.8037471005467007 0.9383780510300384 0.784640171858217 0.9581514868101848
89.33780026435852
epoch:  87
0.8034338477143731 0.9383954996709892 0.795515574650913 0.9470712069278111
86.03695917129517
epoch:  88
0.8024269636104626 0.9394619239827614 0.7948442534908701 0.94

In [6]:
torch.save(model.state_dict(), 'C:/Users/paxso/galclass_da/79_sourceonly_network.pt')

In [7]:
test_accuracy, loss = test_loop(test_dataloader, model, loss_fn)
print(test_accuracy, loss)

0.800483351235231 0.9423054456710815


In [8]:
root_dir = 'C:/Users/paxso/galclass_da/gal_img_full/gal_img_full/'
csv_file = root_dir + 'hsc_dataframe.csv'
domain = 'target'
batch_size = 32
target_data = get_dataloader(csv_file, root_dir, domain, batch_size, train_size=.01, val_size=.01, test_size=.98)
target_train_dataset, target_val_dataset, target_test_dataset, target_train_dataloader, target_val_dataloader, target_test_dataloader = target_data
target_test_accuracy, target_loss = test_loop(target_test_dataloader, model, loss_fn)
print(target_test_accuracy, target_loss)

0.6524872152487216 1.0892262398311408


In [None]:
# Define reversal layer for WANN
class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x):

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg()

        return output