In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
import cv2
from torch.utils.data import Dataset
import multiprocessing
num_cpu_cores = multiprocessing.cpu_count()
from torch.utils.data import DataLoader
import os
from copy import deepcopy
import math

In [2]:
def process_dir_list(input_list):
    """Sort a directory listing in place and drop every macOS ``.DS_Store`` entry.

    The list object itself is mutated (callers rely on this; nothing is returned).
    Empty listings are accepted, and ``.DS_Store`` is removed wherever it sorts,
    not only at the ends of the list.

    Args:
        input_list: list of entry names, e.g. from ``os.listdir``.
    """
    input_list.sort()
    # Slice assignment keeps the same list object so the caller sees the change.
    input_list[:] = [entry for entry in input_list if entry != ".DS_Store"]

In [3]:
# Keypoint names. Position matters: skeleton edges and label columns refer to
# body parts by their index in this list.
body_parts = "nose left_ear right_ear top_neck left_hip right_hip tail_base tail_end".split()

num_body_parts = len(body_parts)

# First CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
class Camera:
    """A calibrated camera: intrinsics ``K``, distortion ``D`` and a 4x4 extrinsic ``T``.

    ``set_T`` (also called from ``__init__``) derives the rotation matrix ``R``,
    the Rodrigues vector ``rvec``, the translation ``tvec`` and the camera centre
    ``cam_pos`` from ``T``.
    """

    def __init__(self, name, K, D, T, fisheye):
        self.name = name
        self.K = K
        self.D = D
        self.set_T(T)
        self.fisheye = fisheye  # selects cv2.fisheye vs. the pinhole projection model
        # Echo the derived pose so it appears in the notebook output.
        for shown in (name, self.rvec, self.tvec):
            print(shown)

    def set_T(self, new_T):
        """Replace the extrinsic transform and refresh every quantity derived from it."""
        self.T = new_T
        # Order matters: rvec needs R, and cam_pos needs both R and tvec.
        for refresh in (self.update_R_mx, self.update_R_vec,
                        self.update_T_vec, self.update_cam_pos):
            refresh()

    def update_R_mx(self):
        """Take the upper-left 3x3 rotation block of ``T`` (a view, not a copy)."""
        self.R = self.T[0:3, 0:3]

    def update_R_vec(self):
        """Convert ``R`` into a flat 3-element Rodrigues rotation vector."""
        rodrigues_out = cv2.Rodrigues(self.R)
        self.rvec = rodrigues_out[0].flatten()

    def update_T_vec(self):
        """Copy the translation column of ``T`` into a flat 3-element vector."""
        self.tvec = self.T[:3, 3].copy()

    def update_cam_pos(self):
        """Camera centre C = -R^T t (the point that ``T`` maps to the origin)."""
        self.cam_pos = -(self.R.T @ self.tvec)
        
# Calibration of the four cameras. The list order below is the camera order used
# everywhere else: cams[i], image channel i of the model input, and label rows
# 2*i and 2*i+1.
N_cams = 4

# Camera identifiers (presumably hardware serial numbers — confirm).
name1 = "17391290"
name2 = "17391304" # top cam
name3 = "19412282"
name4 = "21340171"
names = [name1,name2,name3,name4]

# 3x3 intrinsic matrices [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
# K1, K3 and K4 are identical (one shared calibration for the three side cameras);
# K2 belongs to the top camera.
K1 = np.array([[365.99998268853614, 0.0, 653.5836307711236],[0.0, 362.54269203501826, 524.2898251351193],[ 0.0, 0.0, 1.0]], dtype=np.float64)
K2 = np.array([[1681.4567542003983, 0.0, 639.5656045632381],[0.0, 1684.09437531831, 511.39077900281455],[0.0, 0.0, 1.0]], dtype=np.float64)
K3 = np.array([[365.99998268853614, 0.0, 653.5836307711236],[0.0, 362.54269203501826, 524.2898251351193],[0.0, 0.0, 1.0]], dtype=np.float64)
K4 = np.array([[365.99998268853614, 0.0, 653.5836307711236],[0.0, 362.54269203501826, 524.2898251351193],[0.0, 0.0, 1.0]], dtype=np.float64)
Ks = [K1,K2,K3,K4]

# Distortion coefficients. D1, D3 and D4 (identical) have the 4 entries that
# cv2.fisheye expects. D2 has 14 entries, the longest layout accepted by
# cv2.projectPoints (k1, k2, p1, p2, k3, k4, k5, k6, s1..s4, taux, tauy).
D1 = np.array([0.06606845837702091, -0.06201870735392536, 0.05175078971151844, -0.013604518309950266], dtype=np.float64)
D2 = np.array([-0.09329021100102836, -0.015373120930106288, 0.01682495507715212, -0.0272950838970061, 2.051106767874866, 0.09422046169167524, -0.004912591698594109, 1.438207444196221, 0.026447670217746067, 0.005462499898000899, -0.015115712037627299, -0.012697358649902119, -0.029058318376981435, -0.06767148991259063],dtype=np.float64)
D3 = np.array([0.06606845837702091, -0.06201870735392536, 0.05175078971151844, -0.013604518309950266], dtype=np.float64)
D4 = np.array([0.06606845837702091, -0.06201870735392536, 0.05175078971151844, -0.013604518309950266], dtype=np.float64)
Ds = [D1,D2,D3,D4]

# 4x4 homogeneous extrinsics [R | t; 0 0 0 1]. Camera.set_T splits each one into
# the rvec/tvec that are handed to OpenCV's projectPoints, i.e. world -> camera.
# BigT2 (top camera) is close to the identity, so the world frame appears to be
# essentially the top camera's frame.
BigT1 = np.array([[ -0.9988512200551544, 0.043491384384982415, -0.02011814302084402, -0.763123986941524], [ -0.028159114981956473, -0.1930294179994544, 0.9807888192828389, -31.06215637460763], [ 0.038772470101379686, 0.9802286178596653, 0.1940323485689126, 1.594259421824204], [ 0.0, 0.0, 0.0, 1.0]],dtype=np.float64)
BigT2 = np.array([[ 0.9999569195887527, 0.001794420968667116, -0.009107086249737644, 0.2888088672435886], [ -0.001752078231171209, 0.9999876292272218, 0.004655278121889946, -0.15001638716214752], [ 0.009115327116719876, -0.004639121243026419, 0.9999476933148291, 0.034206925537038246], [ 0.0, 0.0, 0.0, 1.0]],dtype=np.float64)
BigT3 = np.array([[ 0.44291242408315173, -0.895954991190699, -0.033064155111858284, -0.12309494502546652], [ -0.10124781416453565, -0.08662625383473538, 0.991082626360418, -30.298525509057924], [ -0.8908296496635539, -0.4356151350757163, -0.1290813285230491, 7.976794111722069], [ 0.0, 0.0, 0.0, 1.0]],dtype=np.float64)
BigT4 = np.array([[ 0.5093793508741311, 0.8569679205886, -0.07834960105259109, 3.998798335582383], [ 0.02952233249985642, 0.07359067539717384, 0.9968514655546011, -30.49690829849623], [ 0.860035527630584, -0.5100886154156347, 0.012185878570874109, 5.645966786117385], [ 0.0, 0.0, 0.0, 1.0]],dtype=np.float64)
BigTs = [BigT1,BigT2,BigT3,BigT4]

# One Camera per calibration entry. The number of cameras comes from N_cams rather
# than a literal, and the pinhole camera is identified by name: the top camera
# (name2) is the only non-fisheye lens.
assert len(names) == len(Ks) == len(Ds) == len(BigTs) == N_cams, "need one calibration entry per camera"

cams = []
for name, K, D, BigT in zip(names, Ks, Ds, BigTs):
    fisheye = name != name2
    cam = Camera(name, K, D, BigT, fisheye)
    cams.append(cam)

17391290
[-0.01869511 -1.96530515 -2.39112972]
[ -0.76312399 -31.06215637   1.59425942]
17391304
[-0.00464728 -0.00911137 -0.00177328]
[ 0.28880887 -0.15001639  0.03420693]
19412282
[-1.52171554  0.9148925   0.84763451]
[ -0.12309495 -30.29852551   7.97679411]
21340171
[-1.36539405 -0.85024314 -0.74972409]
[  3.99879834 -30.4969083    5.64596679]


In [5]:
class utils:
    """Small tensor/array helpers shared by the projection code."""

    @classmethod
    def check_shape(cls, tensor, shape):
        """Assert that the leading dimensions of ``tensor`` equal ``shape``.

        Only ``len(shape)`` dimensions are compared, so ``shape`` may be a prefix of
        the full size.

        Raises:
            AssertionError: if ``shape`` has more entries than ``tensor`` has
                dimensions, or a compared dimension differs.
        """
        tensor_size = tensor.size()
        assert len(shape) <= len(tensor_size), f"expected at least {len(shape)} dims, got {tuple(tensor_size)}"
        for i in range(len(shape)):
            assert tensor_size[i] == shape[i], f"dim {i}: expected {shape[i]}, got {tensor_size[i]}"

    @classmethod
    def check_type(cls, tensor, dtype):
        """Assert that ``tensor`` has exactly ``dtype``."""
        assert tensor.dtype == dtype, f"expected dtype {dtype}, got {tensor.dtype}"

    @classmethod
    def check_tensor(cls, tensor, shape=None, dtype=None):
        """Run ``check_shape`` and/or ``check_type`` for whichever of ``shape``/``dtype`` is given."""
        if shape is not None:
            cls.check_shape(tensor, shape)
        if dtype is not None:
            cls.check_type(tensor, dtype)

    @classmethod
    def get_base_imgs(cls, cam_names, base_img_dir):
        """Load one grayscale background image per camera, scaled to a maximum of 1.

        Paths are built by plain string concatenation,
        ``base_img_dir + "camera_" + name + "_base_img.png"``, so ``base_img_dir``
        must end with a path separator. An all-black image is returned as zeros.

        Returns:
            list of float32 arrays, one per entry of ``cam_names``.

        Raises:
            FileNotFoundError: if OpenCV cannot read one of the images.
        """
        base_images = []
        for cam_name in cam_names:
            base_img_path = base_img_dir + "camera_" + cam_name + "_base_img.png"
            base_img = cv2.imread(base_img_path)
            if base_img is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read base image: {base_img_path}")
            base_img = cv2.cvtColor(base_img, cv2.COLOR_BGR2GRAY)
            base_img = base_img.astype("float32", copy=False)
            peak = np.max(base_img)
            if peak > 0:
                # Dividing a black image by its zero maximum would produce NaNs.
                base_img /= peak
            base_images.append(base_img)
        return base_images

    @classmethod
    def to_numpy(cls, thing):
        """Return ``thing`` as a numpy array.

        ndarrays are returned as-is; tensors are detached and copied to the CPU;
        lists and tuples go through ``np.array``.

        Raises:
            TypeError: for any other type.
        """
        if isinstance(thing, np.ndarray):
            return thing
        if torch.is_tensor(thing):
            return thing.detach().cpu().numpy()
        if isinstance(thing, (list, tuple)):
            return np.array(thing)
        raise TypeError("Please pass a list, tuple, tensor, or ndarray.")

class ProjectFunction(torch.autograd.Function):
    """Differentiable wrapper around OpenCV's 3-D -> 2-D point projection.

    ``forward`` projects points with ``cv2.projectPoints`` (pinhole) or
    ``cv2.fisheye.projectPoints`` (fisheye). ``backward`` propagates gradients to
    the 3-D points only, through a central finite-difference Jacobian computed in
    ``forward``. The camera parameters receive no gradient; their OpenCV Jacobian
    is saved but unused.
    """

    @staticmethod
    def forward(ctx, input, R_vec, T_vec, K_mx, D_vec, fisheye, epsilon):
        """Project 3-D points to pixel coordinates.

        Args:
            input: tensor of 3-D points, shape (N, 3), (N, 1, 3) or a single (3,).
            R_vec, T_vec: Rodrigues rotation and translation (tensor, ndarray or list).
            K_mx: 3x3 intrinsic matrix (ndarray).
            D_vec: distortion coefficients (ndarray).
            fisheye: True selects ``cv2.fisheye.projectPoints``.
            epsilon: finite-difference step used for the input Jacobian.

        Returns:
            (N, 2) tensor of pixel coordinates (squeezed to (2,) when N == 1), with
            the dtype and device of ``input``.
        """
        input_numpy = utils.to_numpy(torch.clone(input))
        if input_numpy.ndim != 3:
            # OpenCV expects (N, 1, 3); (N, 3) inputs and a lone (3,) point are reshaped.
            input_numpy = input_numpy.reshape(-1, 1, 3)
        assert input_numpy.ndim == 3

        # K_mx and D_vec are passed to OpenCV unchanged (expected to be numpy arrays).
        cv_inputs = [input_numpy, utils.to_numpy(R_vec), utils.to_numpy(T_vec), K_mx, D_vec]

        def project_helper(points):
            """Project ``points`` with the fixed camera inputs; returns (image_points, jacobian)."""
            args = [points] + cv_inputs[1:]
            if fisheye:
                return cv2.fisheye.projectPoints(*args)
            return cv2.projectPoints(*args)

        def as_tensor(array):
            return torch.tensor(array, dtype=input.dtype, device=input.device)

        out_obj = project_helper(input_numpy)
        image_points = out_obj[0]
        cv_jacobian = out_obj[1]

        # Columns of OpenCV's Jacobian for (rvec, tvec). In the pinhole model they are
        # the first six. The fisheye layout is (f[2], c[2], k[4], rvec[3], tvec[3], alpha).
        pose_cols = slice(8, 14) if fisheye else slice(0, 6)
        proj_jacobian_tensor = as_tensor(cv_jacobian[:, pose_cols])

        input_jacobian = None
        if ctx.needs_input_grad[0]:
            def estimate_grad(dim):
                """Central difference d(pixels)/d(point[dim]), shaped (N, 2, 1)."""
                points_minus = input_numpy.copy()
                points_plus = input_numpy.copy()
                points_minus[:, :, dim] -= epsilon
                points_plus[:, :, dim] += epsilon
                out_minus = as_tensor(project_helper(points_minus)[0])
                out_plus = as_tensor(project_helper(points_plus)[0])
                return ((out_plus - out_minus) / (2 * epsilon)).reshape(-1, 2, 1)

            # Only computed when a gradient w.r.t. the points can be requested. This
            # avoids six extra projections, e.g. for large torch.no_grad() calls.
            input_jacobian = torch.cat([estimate_grad(dim) for dim in range(3)], dim=2)
            assert input_jacobian.dim() == 3
            assert input_jacobian.shape[1] == 2
            assert input_jacobian.shape[2] == 3

        ctx.save_for_backward(input_jacobian, proj_jacobian_tensor)
        # Needed so the gradient can be returned in exactly the input's shape
        # (squeeze() would otherwise turn a (1, 3) gradient into (3,)).
        ctx.input_shape = input.shape
        return as_tensor(image_points.reshape(-1, 2)).squeeze()

    @staticmethod
    def backward(ctx, grad_output):
        """Chain rule through the finite-difference Jacobian, per point.

        Returns the gradient for ``input`` (shaped like ``input``), or None when no
        Jacobian was computed. All other arguments get None.
        """
        input_jacobian, _proj_jacobian_tensor = ctx.saved_tensors
        grad_input = None
        if input_jacobian is not None and ctx.needs_input_grad[0]:
            # (N, 1, 2) @ (N, 2, 3) -> (N, 1, 3), then back to the input's own shape.
            grad_output_shaped = grad_output.reshape(-1, 1, 2)
            grad_input = grad_output_shaped.matmul(input_jacobian).reshape(ctx.input_shape)
        # No gradients for R_vec, T_vec, K_mx, D_vec, fisheye, epsilon.
        return grad_input, None, None, None, None, None, None

def project(input, R_vec=torch.randn(3, dtype=torch.double), T_vec=torch.randn(3, dtype=torch.double), K_mx=np.random.rand(3, 3), D_vec=np.zeros(4, dtype=np.double), fisheye=False, epsilon=1e-6):
    """Project 3-D points to pixels through the autograd-aware ``ProjectFunction``.

    The defaults are random values drawn once, when this ``def`` executes. Repeated
    calls that rely on them (e.g. from ``gradcheck``) therefore share one fixed
    dummy camera. Real callers pass a camera's rvec/tvec/K/D/fisheye explicitly.
    """
    apply_args = (input, R_vec, T_vec, K_mx, D_vec, fisheye, epsilon)
    return ProjectFunction.apply(*apply_args)

from torch.autograd import gradcheck

# Sanity check for ProjectFunction.backward: gradcheck compares the gradient it
# returns against torch's own numerical Jacobian on 10 random double-precision
# points. The points are named test_points so the notebook does not shadow the
# built-in input().
test_points = torch.randn(10, 3, dtype=torch.double, requires_grad=True)
test = gradcheck(project, test_points, eps=1e-6, atol=1e-4)
print(test)

True


In [26]:
# Voxel grid over the arena, in inches: x and y span [-r, r], z spans
# [0, arena_height], with `ppi` samples per inch.
r = 9
arena_height = 39
ppi = 5  # points per inch
num_x_points = num_y_points = 2 * r * ppi + 1
num_z_points = arena_height * ppi + 1
# following two vals were painstakingly tested/computed but are still noisy
center_x = -1.199405162390885
center_y = -1.6960642799627483
x = np.linspace(-r, r, num_x_points)
y = np.linspace(-r, r, num_y_points)
z = np.linspace(0, arena_height, num_z_points)
inner_radius = 3.0
outer_radius = 6.693
xxx, yyy, zzz = np.meshgrid(x, y, z)
# One (x, y, z) row per voxel, in reverse ravel order. This is the same array as
# np.rot90 applied to the stacked (3, M) coordinates.
all_points = np.stack([xxx.ravel(), yyy.ravel(), zzz.ravel()], axis=1)[::-1]
# .copy() materialises the reversed (negative-stride) view, which torch rejects.
all_points_tensor = torch.tensor(all_points.copy(), dtype=torch.double, device=device)

# Project every voxel into every camera once; generate_model_input reuses these
# pixel positions for every sample.
all_image_points = []
with torch.no_grad():
    for i in range(N_cams):
        cam = cams[i]
        image_points = project(
            all_points_tensor,
            R_vec=torch.tensor(cam.rvec, dtype=torch.double, device=device),
            T_vec=torch.tensor(cam.tvec, dtype=torch.double, device=device),
            K_mx=cam.K,
            D_vec=cam.D,
            fisheye=cam.fisheye,
        )
        all_image_points.append(image_points)



In [27]:
def bilinear_interp(image, image_points):
    """Sample ``image`` at sub-pixel locations with bilinear interpolation.

    Uses the matrix form of bilinear interpolation on a unit pixel cell
    (https://en.wikipedia.org/wiki/Bilinear_interpolation). The image size is read
    from ``image`` instead of being fixed at 1024x1280.

    Corner indices are clamped to the image. For points outside
    [0, W-1) x [0, H-1), both corners along the offending axis collapse onto the
    same pixel, the two weights cancel exactly, and the sample is 0. NaN
    coordinates are not handled.

    Args:
        image: 2-D tensor (or array) indexed as ``image[row, col]``, i.e. ``[y, x]``.
        image_points: (N, 2) tensor of ``(x, y)`` pixel coordinates.

    Returns:
        (N,) tensor on ``image_points``' device and dtype.
    """
    image = torch.as_tensor(image)
    image_x = image_points[:, 0]
    image_y = image_points[:, 1]
    max_row = image.shape[0] - 1
    max_col = image.shape[1] - 1

    x_floor = torch.floor(image_x)
    y_floor = torch.floor(image_y)
    # The upper corner is floor + 1 computed before clamping, so it collapses onto
    # the lower corner at and beyond the border.
    x1 = x_floor.clamp(0, max_col)
    x2 = (x_floor + 1).clamp(0, max_col)
    y1 = y_floor.clamp(0, max_row)
    y2 = (y_floor + 1).clamp(0, max_row)

    def corner(rows, cols):
        # Tensor indexing on the image's own device; values then follow the points.
        vals = image[rows.long().to(image.device), cols.long().to(image.device)]
        return vals.to(device=image_points.device, dtype=image_points.dtype)

    q11 = corner(y1, x1)
    q12 = corner(y2, x1)
    q21 = corner(y1, x2)
    q22 = corner(y2, x2)

    wx1 = x2 - image_x
    wx2 = image_x - x1
    wy1 = y2 - image_y
    wy2 = image_y - y1
    # [wx1 wx2] @ [[q11 q12], [q21 q22]] @ [wy1 wy2]^T, written out per point.
    return wx1 * (q11 * wy1 + q12 * wy2) + wx2 * (q21 * wy1 + q22 * wy2)

In [28]:
def idx_converter(point, range_begin, range_end, num_points):
    """Map a coordinate on an evenly spaced grid back to its integer index.

    ``point`` is expected to be one of the samples of
    ``np.linspace(range_begin, range_end, num_points)``. The quotient is rounded to
    the nearest index rather than truncated: floating-point error can leave it just
    below the true integer (e.g. 2.9999999999999996), and ``int()`` would then
    return the neighbouring index.
    """
    interval_length = (range_end - range_begin) / (num_points - 1)
    return int(round((point - range_begin) / interval_length))

# Integer (x, y, z) grid index of every row of all_points. These lists scatter the
# per-point samples into the dense (z, y, x) volume in generate_model_input.
x_indices, y_indices, z_indices = [], [], []
for px, py, pz in all_points:
    x_indices.append(idx_converter(px, x[0], x[-1], num_x_points))
    y_indices.append(idx_converter(py, y[0], y[-1], num_y_points))
    z_indices.append(idx_converter(pz, z[0], z[-1], num_z_points))

In [29]:
def generate_model_input(images):
    """Build the model's input volume from one image per camera.

    Camera ``c``'s image is bilinearly sampled at ``all_image_points[c]`` (the
    precomputed projections of the voxel grid). The samples are scattered into a
    (N_cams, len(z), len(y), len(x)) double tensor on ``device``; voxels not hit by
    the index lists stay 0.
    """
    volume = torch.zeros(N_cams, len(z), len(y), len(x), dtype=torch.double, device=device)
    for cam_idx in range(N_cams):
        samples = bilinear_interp(images[cam_idx], all_image_points[cam_idx])
        volume[cam_idx, z_indices, y_indices, x_indices] = samples.to(torch.double).to(device)
    return volume

In [30]:
class TrainingData(Dataset):
    """Dataset with one sample per trial directory under ``training_data_dir``.

    A trial directory holds ``num_images_per_trial`` image files and, when
    labelled, a ``labels.csv``. Images are paired with cameras by sorted filename
    order (assumed to match the camera order used elsewhere — confirm).

    ``__getitem__`` returns ``(images, labels)``:
      * images: double tensor (num_images, image_shape_y, image_shape_x), each image
        grayscale and scaled so its brightest pixel is 1, then passed through
        ``transform`` if one is given;
      * labels: double tensor (num_images * 2, num_body_parts), then passed through
        ``target_transform`` if one is given.
    """

    def __init__(self, transform=None, target_transform=None, num_body_parts=num_body_parts, image_shape_x=1280, image_shape_y=1024, training_data_dir=os.path.join(os.getcwd(), "Training_Data"), num_images_per_trial=4):
        trial_dirs = os.listdir(training_data_dir)
        process_dir_list(trial_dirs)
        self.data = [os.path.join(training_data_dir, trial_dir) for trial_dir in trial_dirs]
        self.num_images = num_images_per_trial
        self.image_shape_x = image_shape_x
        self.image_shape_y = image_shape_y
        self.num_body_parts = num_body_parts
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_image(image_path):
        """Read an image as grayscale float64 scaled so its maximum is 1 (black stays 0).

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img_bw = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype("float64", copy=False)
        peak = np.max(img_bw)
        if peak > 0:
            img_bw /= peak
        return img_bw

    def __getitem__(self, idx):
        trial_path = self.data[idx]
        trial_data = os.listdir(trial_path)
        process_dir_list(trial_data)
        # Every entry except the labels file is an image, wherever labels.csv sorts.
        image_files = [entry for entry in trial_data if entry != "labels.csv"]
        if len(image_files) < self.num_images:
            raise FileNotFoundError(f"Expected {self.num_images} images in {trial_path}, found {len(image_files)}")
        images = [self._load_image(os.path.join(trial_path, entry)) for entry in image_files[:self.num_images]]
        out = torch.tensor(np.stack(images), dtype=torch.double).reshape(self.num_images, self.image_shape_y, self.image_shape_x)
        if self.transform is not None:
            out = self.transform(out)
        if "labels.csv" not in trial_data:
            # NOTE(review): unlabelled trials get random labels that training will fit
            # as if they were real; consider skipping such trials instead.
            print("WARNING! Labels for training not found.")
            labels = torch.rand(self.num_images*2, self.num_body_parts, dtype=torch.double)
            return out, labels
        labels_path = os.path.join(trial_path, "labels.csv")
        labels_numpy = np.genfromtxt(labels_path, delimiter=',')
        labels = torch.tensor(labels_numpy, dtype=torch.double).reshape(self.num_images*2, self.num_body_parts)
        if self.target_transform is not None:
            labels = self.target_transform(labels)
        return out, labels

Training_Data = TrainingData(transform=generate_model_input)
# generate_model_input allocates every sample on `device`, so on a GPU machine the
# samples are already CUDA tensors. DataLoader's pin_memory only accepts dense CPU
# tensors and raises on CUDA ones, so it is left off on every device.
# num_workers stays at the default 0 (in-process loading).
data_loader = DataLoader(
    Training_Data,
    batch_size=5,
    shuffle=True,
)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/patrickdwyer/Documents/Scwartz_Lab/Final_Project/notebooks/Training_Data'

In [31]:
class Model(nn.Module):
    """3-D CNN mapping a multi-camera voxel volume to (x, y, z) per keypoint.

    There are four stages of Conv3d(k=4) -> MaxPool3d(k=3, s=3) ->
    Conv3d(k=2, one extra output channel) -> MaxPool3d(k=2, s=2), all with
    padding 1. These are followed by two fully connected layers.
    """

    def __init__(self, in_channels=4, num_keypoints=8, grid_shape=None):
        """
        Args:
            in_channels: channels of the input volume (one per camera; 4 in this notebook).
            num_keypoints: number of predicted 3-D points (8 body parts in this notebook).
            grid_shape: (depth, height, width) of the input volume. Defaults to the
                notebook globals (num_z_points, num_y_points, num_x_points).
        """
        super().__init__()
        if grid_shape is None:
            grid_shape = (num_z_points, num_y_points, num_x_points)
        padding = 1
        stride = 1
        dilation = 1
        kernel_size1, kernel_pool1, poolstride1 = 4, 3, 3
        kernel_size2, kernel_pool2, poolstride2 = 2, 2, 2
        num_stages = 4

        for dim in grid_shape:
            print(dim)

        sizes = list(grid_shape)
        channels = in_channels
        layers = []
        for _ in range(num_stages):
            layers.append(nn.Conv3d(channels, channels, kernel_size1, stride=stride, padding=padding))
            layers.append(nn.MaxPool3d(kernel_pool1, stride=poolstride1, padding=padding))
            layers.append(nn.Conv3d(channels, channels + 1, kernel_size2, stride=stride, padding=padding))
            layers.append(nn.MaxPool3d(kernel_pool2, stride=poolstride2, padding=padding))
            # Track the spatial size through the same four layers to size the first Linear.
            for kernel, step in ((kernel_size1, stride), (kernel_pool1, poolstride1),
                                 (kernel_size2, stride), (kernel_pool2, poolstride2)):
                sizes = [self._out_size(s, kernel, step, padding, dilation) for s in sizes]
            channels += 1
        self.conv_net = nn.Sequential(*layers)

        for dim in sizes:
            print(dim)

        self.num_keypoints = num_keypoints
        self.after_conv_num = sizes[0] * sizes[1] * sizes[2] * channels
        self.l1 = nn.Linear(self.after_conv_num, self.after_conv_num // 2)
        self.l2 = nn.Linear(self.after_conv_num // 2, num_keypoints * 3)

    @staticmethod
    def _out_size(size, kernel, stride, padding, dilation):
        """Output length along one axis of a Conv/MaxPool layer (PyTorch's floor formula)."""
        return math.floor(((size + (2 * padding) - (dilation * (kernel - 1)) - 1) / stride) + 1)

    def forward(self, x):
        """Map a (B, in_channels, D, H, W) volume to (B, num_keypoints, 3) coordinates."""
        conv_out = self.conv_net(x).view(-1, self.after_conv_num)
        lin_out = self.l2(F.relu(self.l1(conv_out)))
        return lin_out.view(-1, self.num_keypoints, 3)

In [32]:
# Build the network in double precision on the selected device; the training loop
# also casts its inputs to double.
model = Model().to(device=device, dtype=torch.double)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print(model)

196
91
91
2
2
2
Model(
  (conv_net): Sequential(
    (0): Conv3d(4, 4, kernel_size=(4, 4, 4), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): MaxPool3d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
    (2): Conv3d(4, 5, kernel_size=(2, 2, 2), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): MaxPool3d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Conv3d(5, 5, kernel_size=(4, 4, 4), stride=(1, 1, 1), padding=(1, 1, 1))
    (5): MaxPool3d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
    (6): Conv3d(5, 6, kernel_size=(2, 2, 2), stride=(1, 1, 1), padding=(1, 1, 1))
    (7): MaxPool3d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
    (8): Conv3d(6, 6, kernel_size=(4, 4, 4), stride=(1, 1, 1), padding=(1, 1, 1))
    (9): MaxPool3d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
    (10): Conv3d(6, 7, kernel_size=(2, 2, 2), stride=(1, 1, 1), padding=(1, 1, 1))
    (11): MaxPool3d(kernel_size=2, str

In [None]:
# Skeleton edges used as a shape prior: [idx1, idx2, length, weight], where idx1 and
# idx2 index into body_parts. The two neck-hip edges get weight 0 because those
# distances probably aren't constant enough. tail_end (index 7) is in no edge.
edges = [
    [0, 1, 1, 1],  # nose - left ear
    [0, 2, 1, 1],  # nose - right ear
    [1, 3, 1, 1],  # left ear - neck
    [2, 3, 1, 1],  # right ear - neck
    [3, 4, 1, 0],  # neck - left hip (weight 0)
    [3, 5, 1, 0],  # neck - right hip (weight 0)
    [4, 6, 1, 1],  # left hip - tail base
    [5, 6, 1, 1],  # right hip - tail base
]

In [None]:
def _make_projector(camera):
    """Return a function projecting (N, 3) points to (N, 2) pixels through ``camera``.

    The camera is bound when the closure is created. A bare
    ``lambda x: project(x, ..., cam...)`` inside a loop would look ``cam`` up at call
    time, so every projector would use the loop's last camera.
    """
    def projector(points):
        return project(points, R_vec=camera.rvec, T_vec=camera.tvec, K_mx=camera.K, D_vec=camera.D, fisheye=camera.fisheye)
    return projector

projectors = [_make_projector(cams[i]) for i in range(N_cams)]

In [None]:
def loss_fn(pred, target, edges=edges, cams=cams, projectors=projectors):
    """Skeleton-length prior plus per-camera reprojection error.

    Args:
        pred: (B, P, 3) predicted 3-D keypoints (P body parts).
        target: (B, 2 * N_cams, P) labels. Rows 2*i and 2*i+1 are assumed to be the
            x and y pixel coordinates of each body part in camera i — confirm
            against labels.csv.
        edges: [idx1, idx2, length, weight] skeleton constraints.
        cams: camera list (not used here; kept for interface compatibility).
        projectors: one callable per camera mapping (N, 3) points to (N, 2) pixels.

    Returns:
        Scalar tensor: sum over edges of weight * |length**2 - squared 3-D distance|,
        plus the summed squared reprojection error over all cameras.
    """
    loss = 0.
    for idx1, idx2, length, weight in edges:
        # Squared Euclidean distance per sample (summed over x, y, z), compared with
        # the squared edge length, so the prior constrains the actual limb length.
        sq_dist = torch.square(pred[:, idx1, :] - pred[:, idx2, :]).sum(dim=-1)
        loss = loss + (weight * torch.abs(length ** 2 - sq_dist)).sum()
    points = pred.reshape(-1, 3)
    for i in range(N_cams):
        projected = projectors[i](points).reshape(-1, 2)
        # target[:, 2i:2i+2, :] is (B, 2, P). Transpose it to (B, P, 2) so each body
        # part's (x, y) lines up with the (B*P, 2) rows of the projection.
        cur_target = target[:, i*2:(i*2)+2, :].transpose(1, 2).reshape(-1, 2)
        cur_target = cur_target.to(device=projected.device, dtype=projected.dtype)
        loss = loss + torch.square(projected - cur_target).sum()
    return loss
        

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Each batch is ``(volumes, labels)``. Both are moved to ``device`` as double
    before the forward pass. The loss and the number of samples seen so far are
    printed every 100 batches.
    """
    size = len(dataloader.dataset)
    model.train()
    seen = 0
    for idx, (batch_in, target) in enumerate(dataloader):
        batch_in = batch_in.to(device=device, dtype=torch.double)
        # Labels leave the Dataset as CPU tensors, while the reprojections in
        # loss_fn are computed on the predictions' device.
        target = target.to(device=device, dtype=torch.double)
        pred = model(batch_in)
        loss = loss_fn(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Count actual samples, so a smaller final batch is reported correctly.
        seen += len(batch_in)
        if idx % 100 == 0:
            loss_item = loss.item()
            print(f"loss: {loss_item:>7f}  [{seen:>5d}/{size:>5d}]")
            gc.collect()
            gc.collect()

In [None]:
n_epochs = 1
for epoch in range(1, n_epochs + 1):
    # Epoch banner, then one full pass over the training data.
    print(f"Epoch {epoch}\n-------------------------------")
    train_loop(data_loader, model, loss_fn, optimizer)
print("Done!")

In [None]:
# Persist only the learned weights (state_dict) in the current working directory.
model_path = os.path.join(os.getcwd(), "trained_model.pt")
torch.save(model.state_dict(), model_path)