In [30]:
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torch import optim
from torch import nn

from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision import transforms

import random
from glob import glob
import pandas as pd
import numpy as np
from PIL import Image

In [31]:
from torchvision.transforms.transforms import Resize
from glob import glob

import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision import transforms

def extract_day(images):
    day = int(images.split('.')[-2][-2:])
    return day

def make_day_array(images):
    day_array = np.array([extract_day(x) for x in images])
    return day_array

def make_combination(length, species, data_frame, direct_name):
    before_file_path = []
    after_file_path = []
    time_delta = []

    for i in range(length):
        
        # 하위 폴더 중에서 랜덤하게 선택을 한다.
        direct = random.randrange(0,len(direct_name))
        # 위에서 결정된 폴더를 선택한다. 
        temp = data_frame[data_frame['version'] == direct_name[direct]]
    
        # 밑은 기존의 코드와 동일합니다.
        sample = temp[temp['species'] == species].sample(2)
        after = sample[sample['day'] == max(sample['day'])].reset_index(drop=True)
        before = sample[sample['day'] == min(sample['day'])].reset_index(drop=True)

        before_file_path.append(before.iloc[0]['file_name'])
        after_file_path.append(after.iloc[0]['file_name'])
        delta = int(after.iloc[0]['day'] - before.iloc[0]['day'])
        time_delta.append(delta)

    combination_df = pd.DataFrame({
        'before_file_path': before_file_path,
        'after_file_path': after_file_path,
        'time_delta': time_delta,
    })

    combination_df['species'] = species

    return combination_df

class TrainDataset(Dataset):
    def __init__(self, combination_df, is_test=None):
        self.combination_df = combination_df
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),                                    
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomAffine((-20, 20)),
            transforms.RandomRotation(90),                                 
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.is_test = is_test

    def __getitem__(self, idx):
        before_image = Image.open(self.combination_df.iloc[idx]['before_file_path'])
        after_image = Image.open(self.combination_df.iloc[idx]['after_file_path'])

        before_image = self.transform(before_image)
        after_image = self.transform(after_image)
        if self.is_test:
            return before_image, after_image
        time_delta = self.combination_df.iloc[idx]['time_delta']
        return before_image, after_image, time_delta

    def __len__(self):
        return len(self.combination_df)

class TestDataset(Dataset):
    def __init__(self, combination_df, is_test=None):
        self.combination_df = combination_df
        self.transform = transforms.Compose([                                                             
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.is_test = is_test

    def __getitem__(self, idx):
        before_image = Image.open(self.combination_df.iloc[idx]['before_file_path'])
        after_image = Image.open(self.combination_df.iloc[idx]['after_file_path'])

        before_image = self.transform(before_image)
        after_image = self.transform(after_image)
        if self.is_test:
            return before_image, after_image
        time_delta = self.combination_df.iloc[idx]['time_delta']
        return before_image, after_image, time_delta

    def __len__(self):
        return len(self.combination_df)

In [32]:
import torch
from torch import nn
from torchvision.models import efficientnet_b2


class CompareCNN(nn.Module):

    def __init__(self):
        super(CompareCNN, self).__init__()
        self.effnet = efficientnet_b2(pretrained=True)
        self.fc_layer = nn.Linear(1000, 1) 

    def forward(self, input):
        x = self.effnet(input)
        output = self.fc_layer(x)
        return output



class CompareNet(nn.Module):

    def __init__(self):
        super(CompareNet, self).__init__()
        self.before_net = CompareCNN()
        self.after_net = CompareCNN()

    def forward(self, before_input, after_input):
        before = self.before_net(before_input)
        after = self.after_net(after_input)
        delta = before - after
        return delta

In [33]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [34]:
def seed_everything(seed): # seed 고정
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


seed_everything(2048)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
lr = 1e-5
epochs = 20
batch_size = 32
valid_batch_size = 50

model = CompareNet().to(device)

# 학습 데이터가 있는 폴더 위치
root_path = './drive/MyDrive/open_224/train_dataset/'

# BC 폴더와 LT 폴더에 있는 하위 폴더를 저장한다.
bc_direct = glob(root_path + '/BC/*')
bc_direct_name = [x[-5:] for x in bc_direct]
lt_direct = glob(root_path + '/LT/*')
lt_direct_name = [x[-5:] for x in lt_direct]

# 하위 폴더에 있는 이미지들을 하위 폴더 이름과 매칭시켜서 저장한다.
bc_images = {key : glob(name + '/*.png') for key,name in zip(bc_direct_name, bc_direct)}
lt_images = {key : glob(name + '/*.png') for key,name in zip(lt_direct_name, lt_direct)}

# 하위 폴더에 있는 이미지들에서 날짜 정보만 따로 저장한다.
bc_dayes = {key : make_day_array(bc_images[key]) for key in bc_direct_name}
lt_dayes = {key : make_day_array(lt_images[key]) for key in lt_direct_name}

bc_dfs = []
for i in bc_direct_name:
    bc_df = pd.DataFrame({
        'file_name':bc_images[i],
        'day':bc_dayes[i],
        'species':'bc',
        'version':i
    })
    bc_dfs.append(bc_df)
    
lt_dfs = []
for i in lt_direct_name:
    lt_df = pd.DataFrame({
        'file_name':lt_images[i],
        'day':lt_dayes[i],
        'species':'lt',
        'version':i
    })
    lt_dfs.append(lt_df)

bc_dataframe = pd.concat(bc_dfs).reset_index(drop=True)
lt_dataframe = pd.concat(lt_dfs).reset_index(drop=True)
total_dataframe = pd.concat([bc_dataframe, lt_dataframe]).reset_index(drop=True)

bc_combination = make_combination(5000, 'bc', total_dataframe, bc_direct_name)
lt_combination = make_combination(5000, 'lt', total_dataframe, lt_direct_name)

bc_train = bc_combination.iloc[:4500]
bc_valid = bc_combination.iloc[4500:]

lt_train = lt_combination.iloc[:4500]
lt_valid = lt_combination.iloc[4500:]

train_set = pd.concat([bc_train, lt_train])
valid_set = pd.concat([bc_valid, lt_valid])



train_dataset = TrainDataset(train_set)
valid_dataset = TestDataset(valid_set)

optimizer = optim.Adam(model.parameters(), lr=lr)

train_data_loader = DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True)

valid_data_loader = DataLoader(valid_dataset,
                               batch_size=valid_batch_size)

Downloading: "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b2_rwightman-bcdf34b7.pth


  0%|          | 0.00/35.2M [00:00<?, ?B/s]

In [35]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [36]:
for epoch in tqdm(range(epochs)):
    for step, (before_image, after_image, time_delta) in tqdm(enumerate(train_data_loader)):
        before_image = before_image.to(device)
        after_image = after_image.to(device)
        time_delta = time_delta.to(device)

        optimizer.zero_grad()
        logit = model(before_image, after_image)
        train_loss = (torch.sum(torch.abs(logit.squeeze(1).float() - time_delta.float())) /
                      torch.LongTensor([batch_size]).squeeze(0).to(device))
        train_loss.backward()
        optimizer.step()

        if step % 15 == 0:
            print('\n=====================loss=======================')
            print(f'\n=====================EPOCH: {epoch}=======================')
            print(f'\n=====================step: {step}=======================')
            print('MAE_loss : ', train_loss.detach().cpu().numpy())

    valid_losses = []
    with torch.no_grad():
        for valid_before, valid_after, time_delta in tqdm(valid_data_loader):
            valid_before = valid_before.to(device)
            valid_after = valid_after.to(device)
            valid_time_delta = time_delta.to(device)


            logit = model(valid_before, valid_after)
            valid_loss = (torch.sum(torch.abs(logit.squeeze(1).float() - valid_time_delta.float())) /
                          torch.LongTensor([valid_batch_size]).squeeze(0).to(device))
            valid_losses.append(valid_loss.detach().cpu())


    print(f'VALIDATION_LOSS MAE : {sum(valid_losses)/len(valid_losses)}')
    checkpoint = {
        'model': model.state_dict(),

    }

    torch.save(checkpoint, 'effnet_b2.pt')

  0%|          | 0/20 [00:00<?, ?it/s]

0it [00:00, ?it/s]




MAE_loss :  10.546714



MAE_loss :  8.278151



MAE_loss :  10.65296



MAE_loss :  9.590607



MAE_loss :  7.8610067



MAE_loss :  8.305878



MAE_loss :  8.335899



MAE_loss :  6.419019



MAE_loss :  7.5992823



MAE_loss :  8.318463



MAE_loss :  6.83419



MAE_loss :  4.783512



MAE_loss :  5.6311145



MAE_loss :  3.8978713



MAE_loss :  4.1336164



MAE_loss :  5.155183



MAE_loss :  4.4884844



MAE_loss :  5.295532



MAE_loss :  3.1586375


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 4.621885776519775


0it [00:00, ?it/s]




MAE_loss :  5.0500445



MAE_loss :  5.5941763



MAE_loss :  4.961217



MAE_loss :  4.1537457



MAE_loss :  3.9612966



MAE_loss :  4.137126



MAE_loss :  4.9423027



MAE_loss :  3.9647079



MAE_loss :  4.4979267



MAE_loss :  3.140445



MAE_loss :  4.065077



MAE_loss :  2.9221148



MAE_loss :  3.127952



MAE_loss :  3.9454057



MAE_loss :  3.5515385



MAE_loss :  4.0271354



MAE_loss :  2.9070034



MAE_loss :  4.174686



MAE_loss :  3.2888708


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 4.1045989990234375


0it [00:00, ?it/s]




MAE_loss :  3.4426003



MAE_loss :  3.214846



MAE_loss :  3.8034272



MAE_loss :  3.1965556



MAE_loss :  3.4599626



MAE_loss :  3.938099



MAE_loss :  4.070279



MAE_loss :  2.9193852



MAE_loss :  3.0308013



MAE_loss :  3.1683505



MAE_loss :  3.505301



MAE_loss :  2.7175763



MAE_loss :  3.2632103



MAE_loss :  2.991326



MAE_loss :  2.91328



MAE_loss :  2.5343878



MAE_loss :  3.0550709



MAE_loss :  4.3699064



MAE_loss :  2.9146702


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.8164780139923096


0it [00:00, ?it/s]




MAE_loss :  3.5748086



MAE_loss :  2.8483005



MAE_loss :  2.4806874



MAE_loss :  3.9473033



MAE_loss :  2.4758577



MAE_loss :  2.6515913



MAE_loss :  3.6983593



MAE_loss :  2.6334064



MAE_loss :  2.836486



MAE_loss :  2.3681183



MAE_loss :  3.1538224



MAE_loss :  3.1975255



MAE_loss :  2.7675161



MAE_loss :  3.9132247



MAE_loss :  2.7758563



MAE_loss :  3.1970892



MAE_loss :  2.1601362



MAE_loss :  3.0390704



MAE_loss :  2.3701267


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.5894927978515625


0it [00:00, ?it/s]




MAE_loss :  2.4733257



MAE_loss :  3.0303578



MAE_loss :  3.3323288



MAE_loss :  2.8755145



MAE_loss :  3.0757716



MAE_loss :  2.5850508



MAE_loss :  3.3967917



MAE_loss :  2.3743486



MAE_loss :  2.2708836



MAE_loss :  3.4570725



MAE_loss :  3.414105



MAE_loss :  1.9057554



MAE_loss :  3.5629854



MAE_loss :  2.715897



MAE_loss :  2.7744591



MAE_loss :  3.218812



MAE_loss :  2.4649553



MAE_loss :  3.679713



MAE_loss :  2.3428164


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.561150074005127


0it [00:00, ?it/s]




MAE_loss :  2.519918



MAE_loss :  2.4038212



MAE_loss :  3.3763123



MAE_loss :  2.030365



MAE_loss :  2.0827122



MAE_loss :  2.165896



MAE_loss :  2.2742016



MAE_loss :  2.5992012



MAE_loss :  2.5799656



MAE_loss :  2.0623055



MAE_loss :  3.1094918



MAE_loss :  2.6973958



MAE_loss :  3.0326731



MAE_loss :  4.3556004



MAE_loss :  2.7415895



MAE_loss :  1.8537961



MAE_loss :  3.302774



MAE_loss :  2.1189177



MAE_loss :  2.4511833


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.3017144203186035


0it [00:00, ?it/s]




MAE_loss :  2.6845326



MAE_loss :  3.0388794



MAE_loss :  2.2365644



MAE_loss :  2.4587145



MAE_loss :  2.1765819



MAE_loss :  3.0309556



MAE_loss :  1.8730221



MAE_loss :  3.672885



MAE_loss :  3.4199476



MAE_loss :  2.8791866



MAE_loss :  2.5667336



MAE_loss :  1.833025



MAE_loss :  2.1946023



MAE_loss :  2.2763844



MAE_loss :  2.169115



MAE_loss :  1.6774828



MAE_loss :  2.53743



MAE_loss :  3.316618



MAE_loss :  2.8272972


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.283508777618408


0it [00:00, ?it/s]




MAE_loss :  2.661403



MAE_loss :  1.9442322



MAE_loss :  2.1206203



MAE_loss :  2.5412192



MAE_loss :  2.6967564



MAE_loss :  3.058722



MAE_loss :  3.054167



MAE_loss :  2.0916264



MAE_loss :  3.9476957



MAE_loss :  1.9496727



MAE_loss :  2.2019444



MAE_loss :  4.10901



MAE_loss :  2.6160088



MAE_loss :  2.3101134



MAE_loss :  1.7850382



MAE_loss :  1.8701003



MAE_loss :  1.9331784



MAE_loss :  2.4183936



MAE_loss :  3.0803156


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.105375289916992


0it [00:00, ?it/s]




MAE_loss :  1.8337612



MAE_loss :  1.297785



MAE_loss :  1.8910514



MAE_loss :  2.0568426



MAE_loss :  1.876241



MAE_loss :  2.6423354



MAE_loss :  3.9293098



MAE_loss :  2.4475224



MAE_loss :  2.1818514



MAE_loss :  1.6622438



MAE_loss :  2.3170414



MAE_loss :  2.152748



MAE_loss :  1.8744836



MAE_loss :  3.914719



MAE_loss :  2.2964878



MAE_loss :  2.8816793



MAE_loss :  2.5220013



MAE_loss :  1.5302644



MAE_loss :  2.65912


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.1415791511535645


0it [00:00, ?it/s]




MAE_loss :  1.592139



MAE_loss :  1.6901081



MAE_loss :  2.4759166



MAE_loss :  3.7146468



MAE_loss :  1.742183



MAE_loss :  1.9224849



MAE_loss :  2.083118



MAE_loss :  1.6394955



MAE_loss :  2.3172655



MAE_loss :  2.2043295



MAE_loss :  2.674057



MAE_loss :  3.2032018



MAE_loss :  2.1893659



MAE_loss :  2.2315571



MAE_loss :  2.1596642



MAE_loss :  2.5522804



MAE_loss :  2.3626456



MAE_loss :  2.2457676



MAE_loss :  2.0493565


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.149613618850708


0it [00:00, ?it/s]




MAE_loss :  2.8907127



MAE_loss :  1.5903649



MAE_loss :  3.065915



MAE_loss :  2.0399623



MAE_loss :  2.478899



MAE_loss :  2.0662622



MAE_loss :  2.5960345



MAE_loss :  1.4687436



MAE_loss :  1.9995108



MAE_loss :  1.8161536



MAE_loss :  2.450087



MAE_loss :  3.2628405



MAE_loss :  2.6089125



MAE_loss :  1.9664396



MAE_loss :  1.9595026



MAE_loss :  2.32656



MAE_loss :  2.5551448



MAE_loss :  2.2004082



MAE_loss :  2.0264745


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.015483856201172


0it [00:00, ?it/s]




MAE_loss :  2.4757586



MAE_loss :  2.012927



MAE_loss :  2.73311



MAE_loss :  2.4001932



MAE_loss :  1.5938928



MAE_loss :  2.0571713



MAE_loss :  1.8082805



MAE_loss :  2.2402759



MAE_loss :  1.5527318



MAE_loss :  1.9791921



MAE_loss :  1.7714013



MAE_loss :  3.891249



MAE_loss :  1.8100879



MAE_loss :  1.5495365



MAE_loss :  2.300773



MAE_loss :  1.5372713



MAE_loss :  2.0047662



MAE_loss :  2.5336864



MAE_loss :  2.0066931


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.9489998817443848


0it [00:00, ?it/s]




MAE_loss :  1.6241066



MAE_loss :  3.3458264



MAE_loss :  2.644607



MAE_loss :  1.9199805



MAE_loss :  2.1425078



MAE_loss :  2.490365



MAE_loss :  1.2274117



MAE_loss :  3.6946633



MAE_loss :  1.3512892



MAE_loss :  1.9872717



MAE_loss :  2.6363106



MAE_loss :  1.639283



MAE_loss :  2.5004313



MAE_loss :  1.9976084



MAE_loss :  1.7792451



MAE_loss :  2.0975113



MAE_loss :  1.3206629



MAE_loss :  1.7233591



MAE_loss :  1.6339161


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 3.023333787918091


0it [00:00, ?it/s]




MAE_loss :  2.0462997



MAE_loss :  1.6344445



MAE_loss :  1.690448



MAE_loss :  1.9507276



MAE_loss :  1.8629344



MAE_loss :  1.6393371



MAE_loss :  1.5206442



MAE_loss :  1.250621



MAE_loss :  1.9370351



MAE_loss :  1.8717127



MAE_loss :  2.5686705



MAE_loss :  1.6908549



MAE_loss :  1.9496069



MAE_loss :  1.9978724



MAE_loss :  1.9191265



MAE_loss :  1.9648218



MAE_loss :  2.846488



MAE_loss :  1.698123



MAE_loss :  1.7769958


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.8903844356536865


0it [00:00, ?it/s]




MAE_loss :  1.5491631



MAE_loss :  1.6079249



MAE_loss :  1.3385274



MAE_loss :  2.3319602



MAE_loss :  2.211188



MAE_loss :  2.4666886



MAE_loss :  1.8689346



MAE_loss :  2.6253328



MAE_loss :  1.9575595



MAE_loss :  3.1830845



MAE_loss :  1.6363577



MAE_loss :  1.6162268



MAE_loss :  1.4995391



MAE_loss :  1.7346207



MAE_loss :  1.7424976



MAE_loss :  1.795449



MAE_loss :  1.7352581



MAE_loss :  2.1994362



MAE_loss :  1.402662


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.8381171226501465


0it [00:00, ?it/s]




MAE_loss :  2.3163104



MAE_loss :  2.2172947



MAE_loss :  3.001358



MAE_loss :  1.2707505



MAE_loss :  2.0051246



MAE_loss :  1.5457504



MAE_loss :  2.3997436



MAE_loss :  2.0639248



MAE_loss :  1.8169985



MAE_loss :  2.1924353



MAE_loss :  1.9479251



MAE_loss :  1.8958955



MAE_loss :  1.550856



MAE_loss :  2.7870626



MAE_loss :  1.5734999



MAE_loss :  2.2908213



MAE_loss :  1.6923885



MAE_loss :  1.6687722



MAE_loss :  1.8495603


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.9073069095611572


0it [00:00, ?it/s]




MAE_loss :  1.2296039



MAE_loss :  2.0404685



MAE_loss :  1.2088523



MAE_loss :  1.732096



MAE_loss :  1.7023858



MAE_loss :  2.9281802



MAE_loss :  1.8734071



MAE_loss :  3.5902724



MAE_loss :  2.1095781



MAE_loss :  1.3579568



MAE_loss :  1.2040863



MAE_loss :  2.2154763



MAE_loss :  2.0896811



MAE_loss :  2.3038359



MAE_loss :  2.3107548



MAE_loss :  2.302017



MAE_loss :  1.5189908



MAE_loss :  1.5946628



MAE_loss :  2.5107813


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.839986801147461


0it [00:00, ?it/s]




MAE_loss :  1.5530834



MAE_loss :  1.9847538



MAE_loss :  1.980579



MAE_loss :  1.9837668



MAE_loss :  1.6313725



MAE_loss :  1.2605139



MAE_loss :  1.8788447



MAE_loss :  1.9611413



MAE_loss :  2.0618925



MAE_loss :  1.9023459



MAE_loss :  1.7058496



MAE_loss :  1.9484454



MAE_loss :  1.4053856



MAE_loss :  1.7584088



MAE_loss :  1.833192



MAE_loss :  1.5627439



MAE_loss :  1.4269247



MAE_loss :  2.5493414



MAE_loss :  1.3720312


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.7191812992095947


0it [00:00, ?it/s]




MAE_loss :  2.0079932



MAE_loss :  1.7071103



MAE_loss :  1.7259929



MAE_loss :  1.7181375



MAE_loss :  1.9954342



MAE_loss :  1.7470163



MAE_loss :  1.8975089



MAE_loss :  2.6066306



MAE_loss :  1.291353



MAE_loss :  1.7790912



MAE_loss :  2.0573509



MAE_loss :  1.4613159



MAE_loss :  1.4242741



MAE_loss :  1.742024



MAE_loss :  1.6479392



MAE_loss :  1.6080406



MAE_loss :  1.5464842



MAE_loss :  1.8835436



MAE_loss :  1.6573942


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.7849459648132324


0it [00:00, ?it/s]




MAE_loss :  1.6970265



MAE_loss :  1.6553632



MAE_loss :  1.7527952



MAE_loss :  1.2433348



MAE_loss :  1.4473858



MAE_loss :  1.1499503



MAE_loss :  1.5431489



MAE_loss :  1.4667459



MAE_loss :  1.5043561



MAE_loss :  2.0175471



MAE_loss :  1.8546735



MAE_loss :  1.8378328



MAE_loss :  1.7889291



MAE_loss :  2.4233186



MAE_loss :  1.9008974



MAE_loss :  1.9584389



MAE_loss :  2.4918177



MAE_loss :  1.6072717



MAE_loss :  1.4317479


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.7833213806152344


In [37]:
test_set = pd.read_csv('./drive/MyDrive/open_224/test_dataset/test_data.csv')
test_set['l_root'] = test_set['before_file_path'].map(lambda x: './drive/MyDrive/open_224/test_dataset/' + x.split('_')[1] + '/' + x.split('_')[2])
test_set['r_root'] = test_set['after_file_path'].map(lambda x: './drive/MyDrive/open_224/test_dataset/' + x.split('_')[1] + '/' + x.split('_')[2])
test_set['before_file_path'] = test_set['l_root'] + '/' + test_set['before_file_path'] + '.png'
test_set['after_file_path'] = test_set['r_root'] + '/' + test_set['after_file_path'] + '.png'

test_dataset = TestDataset(test_set, is_test=True)
test_data_loader = DataLoader(test_dataset,
                               batch_size=64)

In [38]:
test_value = []
with torch.no_grad():
    for test_before, test_after in tqdm(test_data_loader):
        test_before = test_before.to(device)
        test_after = test_after.to(device)
        logit = model(test_before, test_after)
        value = logit.squeeze(1).detach().cpu().float()
        
        test_value.extend(value)

  0%|          | 0/62 [00:00<?, ?it/s]

In [39]:
# submission 형식을 불러온다.
submission = pd.read_csv('./drive/MyDrive/open_224/sample_submission.csv')

# 예측한 값들은 텐서 형태로 변환 시켜준다.
predict = torch.FloatTensor(test_value)

# 음수의 값을 갖는 모든 값들을 1 Day 차이가 발생하도록 바꿔줌
temp_predict = predict.numpy()
temp_predict[np.where(temp_predict<1)] = 1

# 모델의 예측 값을 저장함
submission['time_delta'] = temp_predict
submission.to_csv('effnet_b2.csv', index=False)