In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# 1. 데이터 불러오기

In [23]:
from google.colab import output
# !cp 파일1 파일2 # 파일1을 파일2로 복사 붙여넣기
!cp "/content/drive/MyDrive/BOAZ/dirty_mnist_2nd.zip" "dirty_mnist_2nd.zip"
# data_2.zip을 현재 디렉터리에 압축해제
!unzip "dirty_mnist_2nd.zip"

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
 extracting: 45000.png               
 extracting: 45001.png               
 extracting: 45002.png               
 extracting: 45003.png               
 extracting: 45004.png               
 extracting: 45005.png               
 extracting: 45006.png               
 extracting: 45007.png               
 extracting: 45008.png               
 extracting: 45009.png               
 extracting: 45010.png               
 extracting: 45011.png               
 extracting: 45012.png               
 extracting: 45013.png               
 extracting: 45014.png               
 extracting: 45015.png               
 extracting: 45016.png               
 extracting: 45017.png               
 extracting: 45018.png               
 extracting: 45019.png               
 extracting: 45020.png               
 extracting: 45021.png               
 extracting: 45022.png               
 extracting: 45023.png               
 extracting: 45024.png               


In [20]:
from google.colab import output
# 현재 디렉터리에 dirty_mnist라는 폴더 생성
!mkdir "./dirty_mnist"
#dirty_mnist.zip라는 zip파일을 dirty_mnist라는 폴더에 압축 풀기
!unzip "dirty_mnist_2nd.zip" -d "./dirty_mnist/"
# 현재 디렉터리에 test_dirty_mnist라는 폴더 생성
!mkdir "./test_dirty_mnist"
#test_dirty_mnist.zip라는 zip파일을 test_dirty_mnist라는 폴더에 압축 풀기
!unzip "test_dirty_mnist_2nd.zip" -d "./test_dirty_mnist/"
# 출력 결과 지우기
output.clear()

# 2. 라이브러리 임포트

In [2]:
!pip install torchinfo

Collecting torchinfo
  Downloading https://files.pythonhosted.org/packages/4f/b1/4b310bd715885636e7174b4b52817202fff0ae3609ca2bfb17f28e33e0a1/torchinfo-0.0.8-py3-none-any.whl
Installing collected packages: torchinfo
Successfully installed torchinfo-0.0.8


In [3]:
import os
from typing import Tuple, Sequence, Callable
import csv
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.optim as optim
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

from torchvision import transforms
from torchvision.models import resnet50

## 1. 커스텀 데이터셋 만들기

In [4]:
class MnistDataset(Dataset):
    def __init__(
        self,
        dir,
        image_ids,
        transforms: Sequence[Callable]
    ) -> None:
        self.dir = dir
        self.transforms = transforms

        self.labels = {}
        with open(image_ids, 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                self.labels[int(row[0])] = list(map(int, row[1:]))

        self.image_ids = list(self.labels.keys())

    def __len__(self) -> int:
        return len(self.image_ids)

    def __getitem__(self, index: int) -> Tuple[Tensor]:
        image_id = self.image_ids[index]
        image = Image.open(
            os.path.join(
                self.dir, f'{str(image_id).zfill(5)}.png')).convert('RGB')
        target = np.array(self.labels.get(image_id)).astype(np.float32)

        if self.transforms is not None:
            image = self.transforms(image)

        return image, target

## 2. 이미지 어그멘테이션

In [5]:
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]),
    transforms.RandomRotation(60, expand=False),
    transforms.RandomAffine(30)
])

transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]),
    transforms.RandomRotation(60, expand=False),
    transforms.RandomAffine(30)
])

In [6]:
#trainset = MnistDataset("./dirty_mnist/", 'dirty_mnist_2nd_answer.csv', transforms_train)
#testset = MnistDataset('./test_dirty_mnist/', 'sample_submission.csv', transforms_test)

#train_loader = DataLoader(trainset, batch_size=64, num_workers=8)
#test_loader = DataLoader(testset, batch_size=32, num_workers=4)

In [9]:
trainset = MnistDataset('/content/drive/MyDrive/BOAZ/dirty_mnist_2nd/', '/content/drive/MyDrive/BOAZ/dirty_mnist_2nd_answer.csv', transforms_train)
testset = MnistDataset('/content/drive/MyDrive/BOAZ/test_dirty_mnist/', '/content/drive/MyDrive/BOAZ/sample_submission.csv', transforms_test)

train_loader = DataLoader(trainset, batch_size=64, num_workers=8)
#train_loader = DataLoader(trainset, batch_size=256, num_workers=8)
#test_loader = DataLoader(testset, batch_size=32, num_workers=4)
test_loader = DataLoader(testset, batch_size=32, num_workers=4)

## 3. ResNet50 모형

In [10]:
class MnistModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.resnet = resnet50(pretrained=True)
        self.classifier = nn.Linear(1000, 26)

    def forward(self, x):
        x = self.resnet(x)
        x = self.classifier(x)

        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MnistModel().to(device)
print(summary(model, input_size=(1, 3, 256, 256), verbose=0))

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))


Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet: 1-1                            [1, 1000]                 --
|    └─Conv2d: 2-1                       [1, 64, 128, 128]         9,408
|    └─BatchNorm2d: 2-2                  [1, 64, 128, 128]         128
|    └─ReLU: 2-3                         [1, 64, 128, 128]         --
|    └─MaxPool2d: 2-4                    [1, 64, 64, 64]           --
|    └─Sequential: 2-5                   [1, 256, 64, 64]          --
|    |    └─Bottleneck: 3-1              [1, 256, 64, 64]          75,008
|    |    └─Bottleneck: 3-2              [1, 256, 64, 64]          70,400
|    |    └─Bottleneck: 3-3              [1, 256, 64, 64]          70,400
|    └─Sequential: 2-6                   [1, 512, 32, 32]          --
|    |    └─Bottleneck: 3-4              [1, 512, 32, 32]          379,392
|    |    └─Bottleneck: 3-5              [1, 512, 32, 32]          280,064
|    |    └─Bottleneck: 3-6              [1, 512, 32, 32] 

## 4. 학습하기

In [12]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MultiLabelSoftMarginLoss()

num_epochs = 30 # 10 epoch
model.train()

for epoch in range(num_epochs):
    # cuda cache 초기화
    torch.cuda.empty_cache()
    for i, (images, targets) in enumerate(train_loader):
        optimizer.zero_grad()

        images = images.type(torch.FloatTensor).to(device)
        targets = targets.type(torch.FloatTensor).to(device)

        outputs = model(images)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()
    
        if (i+1) % 50 == 0:
            outputs = outputs > 0.5
            acc = (outputs == targets).float().mean()
            print(f'{epoch}: {loss.item():.5f}, {acc.item():.5f}')



0: 0.69430, 0.54627
0: 0.69046, 0.54447
0: 0.69043, 0.53846
0: 0.69270, 0.53365
0: 0.69127, 0.54026
0: 0.69269, 0.53065
0: 0.69159, 0.54387
0: 0.69332, 0.52885
0: 0.68882, 0.53726
0: 0.69179, 0.53606
0: 0.68844, 0.54026
0: 0.68827, 0.53606
0: 0.68815, 0.53666
0: 0.68754, 0.53846
0: 0.68874, 0.53365




1: 0.68918, 0.54808
1: 0.68554, 0.54507
1: 0.68793, 0.53846
1: 0.68891, 0.53726
1: 0.68727, 0.53966
1: 0.68868, 0.53125
1: 0.68605, 0.54447
1: 0.68862, 0.53065
1: 0.68866, 0.53786
1: 0.68824, 0.53726
1: 0.68612, 0.54267
1: 0.68857, 0.53606
1: 0.68695, 0.53666
1: 0.68752, 0.53846
1: 0.68795, 0.53245




2: 0.68763, 0.54868
2: 0.68482, 0.54507
2: 0.68536, 0.53846
2: 0.68859, 0.53726
2: 0.68491, 0.53966
2: 0.68763, 0.53065
2: 0.68628, 0.54507
2: 0.68915, 0.53065
2: 0.68537, 0.53786
2: 0.68641, 0.53906
2: 0.68285, 0.54387
2: 0.68454, 0.53606
2: 0.68640, 0.53966
2: 0.68684, 0.53666
2: 0.68628, 0.53546




3: 0.68000, 0.55529
3: 0.67294, 0.55769
3: 0.67708, 0.54748
3: 0.67547, 0.54627
3: 0.67701, 0.54928
3: 0.67300, 0.55168
3: 0.66737, 0.56370
3: 0.67743, 0.54447
3: 0.67045, 0.55409
3: 0.67599, 0.55228
3: 0.66275, 0.56550
3: 0.66630, 0.56190
3: 0.67300, 0.55950
3: 0.66732, 0.56671
3: 0.66359, 0.56671




4: 0.65745, 0.57692
4: 0.65254, 0.57933
4: 0.66416, 0.57272
4: 0.66195, 0.57091
4: 0.65789, 0.56911
4: 0.65343, 0.56791
4: 0.65006, 0.58834
4: 0.65979, 0.56130
4: 0.65544, 0.57812
4: 0.65853, 0.58714
4: 0.64647, 0.57993
4: 0.65042, 0.58173
4: 0.66019, 0.58053
4: 0.65170, 0.57873
4: 0.64474, 0.58774




5: 0.64187, 0.60216
5: 0.63465, 0.60637
5: 0.65843, 0.59315
5: 0.65140, 0.59315
5: 0.64429, 0.59435
5: 0.63466, 0.60156
5: 0.63018, 0.61659
5: 0.64914, 0.58173
5: 0.63342, 0.60998
5: 0.64627, 0.58774
5: 0.63883, 0.60637
5: 0.62759, 0.61358
5: 0.63997, 0.60337
5: 0.63050, 0.60637
5: 0.63111, 0.61418




6: 0.61944, 0.63041
6: 0.61306, 0.62620
6: 0.62813, 0.60877
6: 0.63169, 0.61599
6: 0.62769, 0.62500
6: 0.61288, 0.62380
6: 0.61965, 0.62620
6: 0.62218, 0.60817
6: 0.61174, 0.62861
6: 0.63111, 0.62079
6: 0.61983, 0.62921
6: 0.60347, 0.63462
6: 0.62506, 0.61418
6: 0.61368, 0.62620
6: 0.61414, 0.62620




7: 0.60967, 0.64423
7: 0.59411, 0.65325
7: 0.61937, 0.62019
7: 0.61208, 0.62921
7: 0.60175, 0.64844
7: 0.58896, 0.65385
7: 0.59634, 0.65685
7: 0.61514, 0.63582
7: 0.59941, 0.64423
7: 0.61218, 0.63702
7: 0.60505, 0.64964
7: 0.59570, 0.64724
7: 0.61828, 0.63942
7: 0.59942, 0.64603
7: 0.58572, 0.64663




8: 0.57603, 0.67969
8: 0.57743, 0.67368
8: 0.61256, 0.64784
8: 0.58250, 0.65805
8: 0.58414, 0.65986
8: 0.57070, 0.66707
8: 0.58791, 0.66106
8: 0.58393, 0.66526
8: 0.57749, 0.66947
8: 0.57373, 0.67308
8: 0.57574, 0.68209
8: 0.58405, 0.66887
8: 0.58320, 0.64964
8: 0.57639, 0.66526
8: 0.57331, 0.67608




9: 0.55056, 0.70373
9: 0.55660, 0.69411
9: 0.58762, 0.67668
9: 0.56923, 0.67788
9: 0.57074, 0.68750
9: 0.54306, 0.69952
9: 0.54845, 0.69411
9: 0.56592, 0.68570
9: 0.56321, 0.68149
9: 0.56051, 0.69291
9: 0.55971, 0.69111
9: 0.54587, 0.70433
9: 0.56134, 0.69471
9: 0.55847, 0.69531
9: 0.55184, 0.69171




10: 0.52253, 0.72596
10: 0.51979, 0.72416
10: 0.56812, 0.69832
10: 0.53932, 0.71575
10: 0.55307, 0.70673
10: 0.51634, 0.72957
10: 0.53496, 0.72055
10: 0.53538, 0.71575
10: 0.53006, 0.71394
10: 0.53408, 0.71695
10: 0.52608, 0.72716
10: 0.52299, 0.73558
10: 0.55579, 0.70312
10: 0.53347, 0.72416
10: 0.53480, 0.71755




11: 0.50260, 0.74279
11: 0.49269, 0.73978
11: 0.54507, 0.71815
11: 0.52237, 0.72837
11: 0.51808, 0.73377
11: 0.48754, 0.74880
11: 0.50313, 0.75661
11: 0.50768, 0.74279
11: 0.51487, 0.72716
11: 0.50833, 0.73498
11: 0.52649, 0.72776
11: 0.49800, 0.75240
11: 0.52919, 0.72837
11: 0.50816, 0.74459
11: 0.50618, 0.75240




12: 0.48523, 0.76562
12: 0.46564, 0.77524
12: 0.52435, 0.72837
12: 0.49250, 0.75661
12: 0.49789, 0.75601
12: 0.47626, 0.75661
12: 0.47898, 0.76863
12: 0.50483, 0.73918
12: 0.48812, 0.75180
12: 0.47000, 0.75541
12: 0.50255, 0.74940
12: 0.58354, 0.68510
12: 0.51807, 0.73678
12: 0.50011, 0.75180
12: 0.49180, 0.75901




13: 0.45870, 0.77764
13: 0.45305, 0.77404
13: 0.49738, 0.74820
13: 0.46456, 0.77344
13: 0.47051, 0.76923
13: 0.45048, 0.78005
13: 0.46897, 0.76743
13: 0.48339, 0.75481
13: 0.48381, 0.75661
13: 0.46062, 0.78005
13: 0.47485, 0.76562
13: 0.46883, 0.77885
13: 0.47715, 0.76022
13: 0.47949, 0.77163
13: 0.46698, 0.76142




14: 0.42518, 0.80409
14: 0.42562, 0.78966
14: 0.49822, 0.74639
14: 0.45084, 0.78245
14: 0.47022, 0.76382
14: 0.44091, 0.78245
14: 0.44795, 0.78786
14: 0.46349, 0.76983
14: 0.46596, 0.77584
14: 0.43465, 0.79147
14: 0.45363, 0.77945
14: 0.46089, 0.78425
14: 0.48413, 0.76082
14: 0.47418, 0.77704
14: 0.46208, 0.77644




15: 0.42471, 0.79748
15: 0.39079, 0.80709
15: 0.48522, 0.76022
15: 0.43839, 0.79267
15: 0.45269, 0.78305
15: 0.42623, 0.79808
15: 0.43924, 0.79147
15: 0.44892, 0.77524
15: 0.44451, 0.78305
15: 0.43993, 0.78966
15: 0.44621, 0.78486
15: 0.44730, 0.78606
15: 0.47105, 0.76863
15: 0.45080, 0.78365
15: 0.44185, 0.78606




16: 0.40062, 0.81430
16: 0.39242, 0.81010
16: 0.45972, 0.78125
16: 0.42361, 0.80228
16: 0.42656, 0.79928
16: 0.41927, 0.79688
16: 0.43417, 0.79868
16: 0.43453, 0.78606
16: 0.44300, 0.78365
16: 0.41473, 0.80108
16: 0.44213, 0.79808
16: 0.42674, 0.80409
16: 0.43146, 0.78786
16: 0.43382, 0.80048
16: 0.42989, 0.78486




17: 0.38034, 0.82512
17: 0.38636, 0.81550
17: 0.43970, 0.78365
17: 0.42604, 0.79928
17: 0.43378, 0.79387
17: 0.40395, 0.80288
17: 0.42039, 0.80769
17: 0.43786, 0.78786
17: 0.43236, 0.79748
17: 0.40937, 0.80769
17: 0.43066, 0.80288
17: 0.41004, 0.80829
17: 0.46247, 0.77464
17: 0.43280, 0.79567
17: 0.41360, 0.80409




18: 0.37607, 0.83353
18: 0.38804, 0.81971
18: 0.43192, 0.78966
18: 0.41708, 0.80709
18: 0.40210, 0.81671
18: 0.38991, 0.82212
18: 0.41778, 0.80228
18: 0.42462, 0.80108
18: 0.40964, 0.80769
18: 0.38905, 0.81550
18: 0.41211, 0.81611
18: 0.42115, 0.80649
18: 0.42589, 0.79387
18: 0.40832, 0.80409
18: 0.41485, 0.80168




19: 0.38042, 0.83353
19: 0.36781, 0.83173
19: 0.42447, 0.79748
19: 0.39023, 0.82692
19: 0.42063, 0.79748
19: 0.37409, 0.82873
19: 0.37603, 0.82572
19: 0.40297, 0.81490
19: 0.38619, 0.82091
19: 0.38575, 0.82452
19: 0.39515, 0.82031
19: 0.41343, 0.80589
19: 0.41989, 0.81070
19: 0.39906, 0.81190
19: 0.38849, 0.81971




20: 0.35773, 0.84075
20: 0.35695, 0.83293
20: 0.41866, 0.80108
20: 0.38161, 0.83113
20: 0.39062, 0.81851
20: 0.38226, 0.81971
20: 0.38834, 0.82873
20: 0.38256, 0.82572
20: 0.38954, 0.81671
20: 0.36900, 0.82752
20: 0.39908, 0.82091
20: 0.39153, 0.81971
20: 0.39021, 0.81971
20: 0.38974, 0.82031
20: 0.38312, 0.81731




21: 0.35231, 0.85397
21: 0.35145, 0.84375
21: 0.40819, 0.81070
21: 0.38178, 0.82752
21: 0.39149, 0.82151
21: 0.35817, 0.83774
21: 0.37256, 0.83233
21: 0.39225, 0.81851
21: 0.37510, 0.82933
21: 0.36033, 0.83233
21: 0.38507, 0.82692
21: 0.38059, 0.83053
21: 0.39344, 0.81010
21: 0.38664, 0.82632
21: 0.38866, 0.80649




22: 0.34554, 0.84495
22: 0.32862, 0.84916
22: 0.41351, 0.80950
22: 0.36598, 0.83594
22: 0.36278, 0.83353
22: 0.34435, 0.84856
22: 0.37845, 0.82873
22: 0.38016, 0.83594
22: 0.37322, 0.83053
22: 0.37085, 0.82813
22: 0.36250, 0.83654
22: 0.36700, 0.83173
22: 0.39184, 0.82512
22: 0.37050, 0.83654
22: 0.38642, 0.82332




23: 0.32649, 0.85757
23: 0.34004, 0.84796
23: 0.38967, 0.82572
23: 0.36078, 0.84435
23: 0.37774, 0.83534
23: 0.34227, 0.85036
23: 0.34942, 0.84675
23: 0.34906, 0.84555
23: 0.38742, 0.81791
23: 0.36697, 0.82993
23: 0.34162, 0.84435
23: 0.36931, 0.84075
23: 0.38683, 0.82332
23: 0.38107, 0.82572
23: 0.36342, 0.83053


KeyboardInterrupt: ignored

## 5. 추론하기

In [28]:
submit = pd.read_csv('/content/drive/MyDrive/BOAZ/test_dirty_mnist/sample_submission.csv')

model.eval()
batch_size = test_loader.batch_size
batch_index = 0
for i, (images, targets) in enumerate(test_loader):
    images = images.to(device)
    targets = targets.to(device)
    outputs = model(images)
    outputs = outputs > 0.5
    batch_index = i * batch_size
    submit.iloc[batch_index:batch_index+batch_size, 1:] = \
        outputs.long().squeeze(0).detach().cpu().numpy()
    
submit.to_csv('/content/drive/MyDrive/Resnet50_epoch30_batch64_rotation_affine_submit.csv', index=False)

FileNotFoundError: ignored