In [1]:
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torch import optim
from torch import nn

from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision import transforms

import random
from glob import glob
import pandas as pd
import numpy as np
from PIL import Image

In [2]:
from torchvision.transforms.transforms import Resize
from glob import glob

import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision import transforms

def extract_day(images):
    day = int(images.split('.')[-2][-2:])
    return day

def make_day_array(images):
    day_array = np.array([extract_day(x) for x in images])
    return day_array

def make_combination(length, species, data_frame, direct_name):
    before_file_path = []
    after_file_path = []
    time_delta = []

    for i in range(length):
        
        # 하위 폴더 중에서 랜덤하게 선택을 한다.
        direct = random.randrange(0,len(direct_name))
        # 위에서 결정된 폴더를 선택한다. 
        temp = data_frame[data_frame['version'] == direct_name[direct]]
    
        # 밑은 기존의 코드와 동일합니다.
        sample = temp[temp['species'] == species].sample(2)
        after = sample[sample['day'] == max(sample['day'])].reset_index(drop=True)
        before = sample[sample['day'] == min(sample['day'])].reset_index(drop=True)

        before_file_path.append(before.iloc[0]['file_name'])
        after_file_path.append(after.iloc[0]['file_name'])
        delta = int(after.iloc[0]['day'] - before.iloc[0]['day'])
        time_delta.append(delta)

    combination_df = pd.DataFrame({
        'before_file_path': before_file_path,
        'after_file_path': after_file_path,
        'time_delta': time_delta,
    })

    combination_df['species'] = species

    return combination_df

class TrainDataset(Dataset):
    def __init__(self, combination_df, is_test=None):
        self.combination_df = combination_df
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),                                    
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomAffine((-20, 20)),
            transforms.RandomRotation(90),                                 
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.is_test = is_test

    def __getitem__(self, idx):
        before_image = Image.open(self.combination_df.iloc[idx]['before_file_path'])
        after_image = Image.open(self.combination_df.iloc[idx]['after_file_path'])

        before_image = self.transform(before_image)
        after_image = self.transform(after_image)
        if self.is_test:
            return before_image, after_image
        time_delta = self.combination_df.iloc[idx]['time_delta']
        return before_image, after_image, time_delta

    def __len__(self):
        return len(self.combination_df)

class TestDataset(Dataset):
    def __init__(self, combination_df, is_test=None):
        self.combination_df = combination_df
        self.transform = transforms.Compose([                                                             
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.is_test = is_test

    def __getitem__(self, idx):
        before_image = Image.open(self.combination_df.iloc[idx]['before_file_path'])
        after_image = Image.open(self.combination_df.iloc[idx]['after_file_path'])

        before_image = self.transform(before_image)
        after_image = self.transform(after_image)
        if self.is_test:
            return before_image, after_image
        time_delta = self.combination_df.iloc[idx]['time_delta']
        return before_image, after_image, time_delta

    def __len__(self):
        return len(self.combination_df)

In [3]:
import torch
from torch import nn
from torchvision.models import resnet50


class CompareCNN(nn.Module):

    def __init__(self):
        super(CompareCNN, self).__init__()
        self.resnet = resnet50(pretrained=True)
        self.fc_layer = nn.Linear(1000, 1) 

    def forward(self, input):
        x = self.resnet(input)
        output = self.fc_layer(x)
        return output



class CompareNet(nn.Module):

    def __init__(self):
        super(CompareNet, self).__init__()
        self.before_net = CompareCNN()
        self.after_net = CompareCNN()

    def forward(self, before_input, after_input):
        before = self.before_net(before_input)
        after = self.after_net(after_input)
        delta = before - after
        return delta

In [4]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [5]:
def seed_everything(seed): # seed 고정
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


seed_everything(2048)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
lr = 1e-5
epochs = 20
batch_size = 32
valid_batch_size = 50

model = CompareNet().to(device)

# 학습 데이터가 있는 폴더 위치
root_path = './drive/MyDrive/open_224/train_dataset/'

# BC 폴더와 LT 폴더에 있는 하위 폴더를 저장한다.
bc_direct = glob(root_path + '/BC/*')
bc_direct_name = [x[-5:] for x in bc_direct]
lt_direct = glob(root_path + '/LT/*')
lt_direct_name = [x[-5:] for x in lt_direct]

# 하위 폴더에 있는 이미지들을 하위 폴더 이름과 매칭시켜서 저장한다.
bc_images = {key : glob(name + '/*.png') for key,name in zip(bc_direct_name, bc_direct)}
lt_images = {key : glob(name + '/*.png') for key,name in zip(lt_direct_name, lt_direct)}

# 하위 폴더에 있는 이미지들에서 날짜 정보만 따로 저장한다.
bc_dayes = {key : make_day_array(bc_images[key]) for key in bc_direct_name}
lt_dayes = {key : make_day_array(lt_images[key]) for key in lt_direct_name}

bc_dfs = []
for i in bc_direct_name:
    bc_df = pd.DataFrame({
        'file_name':bc_images[i],
        'day':bc_dayes[i],
        'species':'bc',
        'version':i
    })
    bc_dfs.append(bc_df)
    
lt_dfs = []
for i in lt_direct_name:
    lt_df = pd.DataFrame({
        'file_name':lt_images[i],
        'day':lt_dayes[i],
        'species':'lt',
        'version':i
    })
    lt_dfs.append(lt_df)

bc_dataframe = pd.concat(bc_dfs).reset_index(drop=True)
lt_dataframe = pd.concat(lt_dfs).reset_index(drop=True)
total_dataframe = pd.concat([bc_dataframe, lt_dataframe]).reset_index(drop=True)

bc_combination = make_combination(5000, 'bc', total_dataframe, bc_direct_name)
lt_combination = make_combination(5000, 'lt', total_dataframe, lt_direct_name)

bc_train = bc_combination.iloc[:4500]
bc_valid = bc_combination.iloc[4500:]

lt_train = lt_combination.iloc[:4500]
lt_valid = lt_combination.iloc[4500:]

train_set = pd.concat([bc_train, lt_train])
valid_set = pd.concat([bc_valid, lt_valid])



train_dataset = TrainDataset(train_set)
valid_dataset = TestDataset(valid_set)

optimizer = optim.Adam(model.parameters(), lr=lr)

train_data_loader = DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True)

valid_data_loader = DataLoader(valid_dataset,
                               batch_size=valid_batch_size)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [6]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [7]:
for epoch in tqdm(range(epochs)):
    for step, (before_image, after_image, time_delta) in tqdm(enumerate(train_data_loader)):
        before_image = before_image.to(device)
        after_image = after_image.to(device)
        time_delta = time_delta.to(device)

        optimizer.zero_grad()
        logit = model(before_image, after_image)
        train_loss = (torch.sum(torch.abs(logit.squeeze(1).float() - time_delta.float())) /
                      torch.LongTensor([batch_size]).squeeze(0).to(device))
        train_loss.backward()
        optimizer.step()

        if step % 15 == 0:
            print('\n=====================loss=======================')
            print(f'\n=====================EPOCH: {epoch}=======================')
            print(f'\n=====================step: {step}=======================')
            print('MAE_loss : ', train_loss.detach().cpu().numpy())

    valid_losses = []
    with torch.no_grad():
        for valid_before, valid_after, time_delta in tqdm(valid_data_loader):
            valid_before = valid_before.to(device)
            valid_after = valid_after.to(device)
            valid_time_delta = time_delta.to(device)


            logit = model(valid_before, valid_after)
            valid_loss = (torch.sum(torch.abs(logit.squeeze(1).float() - valid_time_delta.float())) /
                          torch.LongTensor([valid_batch_size]).squeeze(0).to(device))
            valid_losses.append(valid_loss.detach().cpu())


    print(f'VALIDATION_LOSS MAE : {sum(valid_losses)/len(valid_losses)}')
    checkpoint = {
        'model': model.state_dict(),

    }

    torch.save(checkpoint, 'resnet50_v3.pt')

  0%|          | 0/20 [00:00<?, ?it/s]

0it [00:00, ?it/s]




MAE_loss :  12.750802



MAE_loss :  9.192764



MAE_loss :  5.4709964



MAE_loss :  4.975234



MAE_loss :  5.487015



MAE_loss :  2.6636267



MAE_loss :  3.546033



MAE_loss :  3.6532307



MAE_loss :  2.4682498



MAE_loss :  3.4034638



MAE_loss :  2.90894



MAE_loss :  3.0827339



MAE_loss :  2.4715023



MAE_loss :  2.1530125



MAE_loss :  1.8898952



MAE_loss :  2.4473672



MAE_loss :  2.2457402



MAE_loss :  1.5961773



MAE_loss :  2.0862608


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.7215826511383057


0it [00:00, ?it/s]




MAE_loss :  3.2068024



MAE_loss :  2.0603557



MAE_loss :  2.0398111



MAE_loss :  2.8899565



MAE_loss :  1.5001197



MAE_loss :  1.4803559



MAE_loss :  2.2545419



MAE_loss :  2.5510488



MAE_loss :  2.0888896



MAE_loss :  1.3650467



MAE_loss :  2.188926



MAE_loss :  2.98002



MAE_loss :  1.5298302



MAE_loss :  1.8165026



MAE_loss :  2.449702



MAE_loss :  2.2068775



MAE_loss :  3.400228



MAE_loss :  1.6610348



MAE_loss :  1.8984163


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.539924144744873


0it [00:00, ?it/s]




MAE_loss :  3.1312368



MAE_loss :  2.4615388



MAE_loss :  2.1852145



MAE_loss :  1.4447902



MAE_loss :  1.9865927



MAE_loss :  1.8747216



MAE_loss :  1.9240724



MAE_loss :  2.0004017



MAE_loss :  2.24544



MAE_loss :  2.6116028



MAE_loss :  2.37829



MAE_loss :  1.9382466



MAE_loss :  2.4415376



MAE_loss :  1.5043823



MAE_loss :  1.6402438



MAE_loss :  3.6164815



MAE_loss :  1.8450122



MAE_loss :  1.7602719



MAE_loss :  2.277209


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.24080753326416


0it [00:00, ?it/s]




MAE_loss :  1.4063561



MAE_loss :  1.746099



MAE_loss :  1.6904279



MAE_loss :  2.7566524



MAE_loss :  1.3740863



MAE_loss :  0.9802615



MAE_loss :  1.4378183



MAE_loss :  2.4098206



MAE_loss :  3.560207



MAE_loss :  1.7250655



MAE_loss :  1.6215932



MAE_loss :  2.1110096



MAE_loss :  1.655453



MAE_loss :  1.830681



MAE_loss :  1.5268893



MAE_loss :  2.5132444



MAE_loss :  1.8002634



MAE_loss :  1.7718784



MAE_loss :  1.9957172


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.3185477256774902


0it [00:00, ?it/s]




MAE_loss :  1.2579532



MAE_loss :  2.1170044



MAE_loss :  1.6183201



MAE_loss :  1.3713888



MAE_loss :  0.88478345



MAE_loss :  1.3774189



MAE_loss :  1.535612



MAE_loss :  3.2608857



MAE_loss :  1.4408791



MAE_loss :  2.3090777



MAE_loss :  1.9462011



MAE_loss :  1.8114507



MAE_loss :  2.0416317



MAE_loss :  1.740463



MAE_loss :  2.8932176



MAE_loss :  1.296229



MAE_loss :  1.1863396



MAE_loss :  1.4057186



MAE_loss :  3.0292096


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 2.156770706176758


0it [00:00, ?it/s]




MAE_loss :  2.0844557



MAE_loss :  2.607998



MAE_loss :  1.4659089



MAE_loss :  0.74889493



MAE_loss :  1.3191689



MAE_loss :  1.2341247



MAE_loss :  1.5371025



MAE_loss :  1.2434416



MAE_loss :  2.2245517



MAE_loss :  1.9194753



MAE_loss :  2.5254598



MAE_loss :  1.8453729



MAE_loss :  1.8544968



MAE_loss :  1.1039283



MAE_loss :  1.7106707



MAE_loss :  1.680567



MAE_loss :  3.5509415



MAE_loss :  2.088635



MAE_loss :  1.6337636


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 1.9605789184570312


0it [00:00, ?it/s]




MAE_loss :  1.0429683



MAE_loss :  1.3808079



MAE_loss :  1.1285472



MAE_loss :  1.168644



MAE_loss :  2.4332087



MAE_loss :  2.1722622



MAE_loss :  1.7059069



MAE_loss :  2.0191734



MAE_loss :  1.4556539



MAE_loss :  1.4922326



MAE_loss :  1.5820282



MAE_loss :  1.3755434



MAE_loss :  1.4617826



MAE_loss :  1.5311443



MAE_loss :  1.546529



MAE_loss :  1.7435257



MAE_loss :  4.8939123



MAE_loss :  2.0831544



MAE_loss :  3.3838363


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 1.9392168521881104


0it [00:00, ?it/s]




MAE_loss :  1.549355



MAE_loss :  1.9760587



MAE_loss :  1.056845



MAE_loss :  1.4805593



MAE_loss :  1.0522213



MAE_loss :  1.851051



MAE_loss :  2.154557



MAE_loss :  1.8590994



MAE_loss :  1.4959695



MAE_loss :  1.1875772



MAE_loss :  1.2332513



MAE_loss :  2.0135388



MAE_loss :  2.2119856



MAE_loss :  1.5722501



MAE_loss :  1.3257252



MAE_loss :  1.2205329



MAE_loss :  1.6139456



MAE_loss :  1.0148681



MAE_loss :  1.4093854


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 1.8712055683135986


0it [00:00, ?it/s]




MAE_loss :  1.7216915



MAE_loss :  1.1741576



MAE_loss :  0.8750288



MAE_loss :  1.7855291



MAE_loss :  2.8976383



MAE_loss :  1.2712028



MAE_loss :  1.0744498



MAE_loss :  2.3461022



MAE_loss :  1.3751884



MAE_loss :  0.97508454



MAE_loss :  3.2492986



MAE_loss :  2.0420275



MAE_loss :  1.2460119



MAE_loss :  1.2456963



MAE_loss :  1.2301904



MAE_loss :  1.1190878



MAE_loss :  1.0179455



MAE_loss :  1.5232396



MAE_loss :  2.4891562


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 1.8459923267364502


0it [00:00, ?it/s]




MAE_loss :  1.3004571



MAE_loss :  2.933858



MAE_loss :  1.1454668



MAE_loss :  1.1790364



MAE_loss :  0.7694746



MAE_loss :  1.4386153



MAE_loss :  1.5134313



MAE_loss :  1.8297192



MAE_loss :  1.0367856



MAE_loss :  2.1297398



MAE_loss :  1.1550124



MAE_loss :  1.0429325



MAE_loss :  1.6469123



MAE_loss :  1.2561134



MAE_loss :  2.249029



MAE_loss :  1.8291329



MAE_loss :  1.1249331



MAE_loss :  1.3473752



MAE_loss :  1.0572114


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 1.8114604949951172


0it [00:00, ?it/s]




MAE_loss :  1.0163586



MAE_loss :  2.4412212



MAE_loss :  0.9567656



MAE_loss :  1.227057



MAE_loss :  1.8934678



MAE_loss :  2.024333



MAE_loss :  1.0291758



MAE_loss :  1.4257188



MAE_loss :  1.7564976



MAE_loss :  3.6083786



MAE_loss :  0.97436064



MAE_loss :  1.0143502



MAE_loss :  2.5347817



MAE_loss :  1.3384113



MAE_loss :  1.4730961



MAE_loss :  2.134297



MAE_loss :  1.7641429



MAE_loss :  1.0437697



MAE_loss :  1.0505507


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 1.752069115638733


0it [00:00, ?it/s]




MAE_loss :  1.2444706



MAE_loss :  1.2408656



MAE_loss :  1.0449063



MAE_loss :  1.691749



MAE_loss :  2.313048



MAE_loss :  1.5625137



MAE_loss :  1.4829984



MAE_loss :  1.1455686



MAE_loss :  1.7206097



MAE_loss :  1.0816519



MAE_loss :  2.4664469



MAE_loss :  1.3430352



MAE_loss :  2.022036



MAE_loss :  3.0805945



MAE_loss :  0.7312263



MAE_loss :  1.3604798



MAE_loss :  1.507252



MAE_loss :  3.4250789



MAE_loss :  1.4954002


  0%|          | 0/20 [00:00<?, ?it/s]

VALIDATION_LOSS MAE : 1.8273241519927979


0it [00:00, ?it/s]




MAE_loss :  1.4155947



MAE_loss :  2.0284042



MAE_loss :  0.86398745



MAE_loss :  1.6609185



MAE_loss :  2.3336754


KeyboardInterrupt: ignored

In [8]:
test_set = pd.read_csv('./drive/MyDrive/open_224/test_dataset/test_data.csv')
test_set['l_root'] = test_set['before_file_path'].map(lambda x: './drive/MyDrive/open_224/test_dataset/' + x.split('_')[1] + '/' + x.split('_')[2])
test_set['r_root'] = test_set['after_file_path'].map(lambda x: './drive/MyDrive/open_224/test_dataset/' + x.split('_')[1] + '/' + x.split('_')[2])
test_set['before_file_path'] = test_set['l_root'] + '/' + test_set['before_file_path'] + '.png'
test_set['after_file_path'] = test_set['r_root'] + '/' + test_set['after_file_path'] + '.png'

test_dataset = TestDataset(test_set, is_test=True)
test_data_loader = DataLoader(test_dataset,
                               batch_size=64)

In [9]:
test_value = []
with torch.no_grad():
    for test_before, test_after in tqdm(test_data_loader):
        test_before = test_before.to(device)
        test_after = test_after.to(device)
        logit = model(test_before, test_after)
        value = logit.squeeze(1).detach().cpu().float()
        
        test_value.extend(value)

  0%|          | 0/62 [00:00<?, ?it/s]

In [10]:
# submission 형식을 불러온다.
submission = pd.read_csv('./drive/MyDrive/open_224/sample_submission.csv')

# 예측한 값들은 텐서 형태로 변환 시켜준다.
predict = torch.FloatTensor(test_value)

# 음수의 값을 갖는 모든 값들을 1 Day 차이가 발생하도록 바꿔줌
temp_predict = predict.numpy()
temp_predict[np.where(temp_predict<1)] = 1

# 모델의 예측 값을 저장함
submission['time_delta'] = temp_predict
submission.to_csv('resnet50_v3.csv', index=False)