# 중간고사 3번: 얼굴 분류
*Multi-class classification*   
`150`개의 실수 특징값(얼굴 픽셀 정보를 PCA로 축소한 값으로 보이며, 실제 데이터에는 -1~1 범위를 벗어나는 값도 있다)을 입력받아 7명의 사람 중 누구의 얼굴인지 분류하는 문제이다.

- CUDA를 사용하도록 했다.
- cross-entropy를 내장 함수로 바꿨다.
- 학습시에 momentum을 `0.9`로 적용하였다.
- bias를 클래스 개수만큼 만들었다.

In [0]:
# Upload a file (kaggle.json, the Kaggle API token) from the local machine into
# the Colab working directory; the next cell copies it into ~/.kaggle.
from google.colab import files
files.upload()

In [0]:
# Install the uploaded API token where the Kaggle CLI reads it, restrict it to
# owner read/write (mode 600), then download the competition data files.
! mkdir -p ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle competitions download -c 2020-ai-exam-facepca-revisit

In [0]:
# Extract the training CSV.
# NOTE(review): the test cell later reads '2020.AI.facePCA-test.csv', but only the
# train archive is unzipped here -- presumably the test file is downloaded
# uncompressed; confirm.
! unzip 2020.AI.facePCA-train.csv.zip

In [0]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import numpy as np

In [6]:
torch.cuda.is_available() # check whether a GPU is available (tensors are placed on CUDA later)

True

In [0]:
torch.manual_seed(1)  # fix the RNG seed so runs are reproducible
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on a machine without CUDA (previously hard-coded 'cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Data

In [14]:
# Load the training set. The first CSV column becomes the row index; the rest
# are 150 feature columns followed by the integer class label. The label column
# shows up as '0.1' (presumably pandas' rename of a duplicated '0' header).
xy_train = pd.read_csv('2020.AI.facePCA-train.csv', header=0, index_col=0)
xy_train.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,0.1
0,-2.075606,-1.04579,2.126936,0.036825,-0.757574,-0.517365,0.855506,1.051939,0.457736,0.01348,-0.039627,0.638728,0.481672,2.337836,1.778472,0.133097,-2.271315,-4.45689,2.097818,-1.137919,0.18844,-0.335002,1.125455,-0.324029,0.140952,1.076946,0.758809,-0.099774,3.11996,0.883766,-0.893408,1.159581,1.430616,1.685677,1.343437,-1.25912,-0.639151,-2.336286,-0.013655,-1.46387,...,-0.746482,1.436504,-1.175173,-0.082779,2.073651,-2.109485,0.351634,-1.139726,-0.081991,-0.444308,2.042309,1.20041,0.04101,0.86173,0.868629,1.227638,0.525842,0.24122,0.595716,0.691755,-1.140556,0.365722,0.557838,0.440183,0.86373,0.326829,-1.658824,0.59499,-0.268711,0.895182,0.76697,-0.424478,-0.124687,-1.496749,0.447682,0.436117,0.456781,-0.871528,2.808375,3
1,1.321112,0.592836,0.534154,0.12266,1.182957,-0.673364,-0.182102,1.064393,0.87006,0.442813,-0.75043,-1.227783,0.513912,-0.67168,-0.34917,-0.063715,1.130124,-0.417614,0.404126,0.449445,-1.333844,0.862824,-0.222651,-0.925683,0.68668,0.034574,-0.543629,-0.08365,0.016256,-0.604505,-0.147177,-0.450206,-0.780342,0.317297,1.143763,0.946406,1.490752,-0.36354,-0.223404,-0.311723,...,0.813659,-0.978788,0.433176,0.339878,-0.710201,0.041666,-1.838961,0.127895,-0.153543,0.936536,0.743923,-0.262089,0.906066,-0.149071,-0.21751,-0.718962,-0.084085,0.218937,0.67265,-0.56744,-0.297983,-0.052386,-0.42778,0.338658,0.518092,-0.125759,-0.312491,-1.245005,0.178577,0.71995,0.374759,-0.317582,-0.199934,-0.573055,0.597456,-0.123165,1.199251,-0.920927,1.424777,1
2,-0.761193,-0.01973,-0.239907,0.499094,1.304381,-0.561011,0.069747,1.62091,0.118996,0.400202,-0.51622,-0.072791,-0.071298,0.99998,-0.235631,-1.66343,-0.438474,-0.545752,-0.116703,0.786574,-0.723442,0.821037,-0.348752,-0.945605,-0.40521,0.248243,0.049621,-0.087621,0.227024,-0.638978,-0.30192,-1.22479,1.076064,-1.048868,-0.405282,-0.455682,-0.202171,0.026434,-0.393673,0.133989,...,-0.287399,0.146545,0.26978,-1.803792,-0.422592,-0.170472,-0.468339,-0.952578,0.444236,-0.099052,-1.221919,-0.183294,0.086852,-0.830793,0.640447,1.003492,-0.425199,-0.053831,-0.703356,-0.535291,-0.485399,1.37237,-0.481861,0.453103,-0.381302,-1.434609,0.569906,0.035467,-1.121713,-0.165375,-1.785485,0.708388,0.00103,-0.01527,1.503353,0.86729,1.289758,-0.995063,-0.750737,2
3,-0.117408,0.116545,-0.009745,2.104061,-0.549831,0.623429,0.885584,0.496275,-0.346002,-0.173676,0.042416,1.845655,0.364923,-0.316649,-0.601811,-0.127524,-0.552606,0.431029,1.932828,0.395474,0.341249,0.228068,-0.665014,-0.280454,1.448531,0.245143,-0.540567,0.072655,-0.879966,0.248609,0.874784,-0.902228,0.742393,0.151671,1.085296,-0.238937,-0.543449,-0.90594,0.218637,1.405076,...,1.418218,0.4883,-0.131573,0.37477,-0.487221,1.047709,-1.006666,2.041681,0.234395,2.235936,2.361929,-2.05524,-0.475777,3.257405,-0.466937,0.226311,-2.482728,0.732327,0.246419,-0.723455,1.962168,-1.237284,1.671197,0.407076,-1.346458,-0.141933,-0.336966,-1.294388,0.745581,-0.042589,-1.107025,-0.31411,1.143012,1.028699,0.271098,0.230415,0.643734,-2.632324,-0.415003,6
4,-0.39637,0.426845,-0.250984,-0.651813,1.795055,0.917808,0.349124,-1.391063,1.175008,1.097817,-0.181041,0.567281,-0.279172,-0.215811,0.176151,0.854819,0.321899,0.723485,1.496162,0.863732,0.508989,-0.463177,1.208637,-2.700064,-0.569516,0.452944,0.219559,0.092291,0.540682,-0.614767,-2.474989,2.133385,0.243832,-0.179368,1.228117,-0.689083,-1.582876,-1.957578,-0.149797,1.220719,...,0.056493,0.067192,0.286282,0.217141,0.757476,0.082412,-0.6435,-0.324384,-0.469117,0.401638,1.753386,2.410296,0.427339,-0.806103,-0.636527,-0.461202,0.218198,-0.949156,-1.500684,1.138279,-0.481768,0.010909,-0.207365,-2.179287,-1.048523,0.398689,0.210045,-0.523262,0.201882,1.491359,0.595545,-0.752304,-0.414062,0.662435,0.014245,0.087689,0.695441,0.475167,-0.314527,4


In [19]:
# Largest label value. The assumption is classes 0 ~ 6 (7 categories); note that
# max() alone only confirms the upper bound, not that labels start at 0.
xy_train.iloc[:, 150].max()

6

In [20]:
# Summary of the frame: row count, column count and dtypes
# (150 float feature columns plus one int label column).
xy_train.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 966 entries, 0 to 965
Columns: 151 entries, 0 to 0.1
dtypes: float64(150), int64(1)
memory usage: 1.1 MB


In [21]:
# Split the frame into the 150 feature columns and the trailing label column,
# then build the tensors directly on the chosen device (features as float32,
# labels as int64 class indices for cross-entropy).
feature_frame = xy_train.iloc[:, :150]
label_series = xy_train.iloc[:, 150]

x_train = torch.tensor(feature_frame.to_numpy(), dtype=torch.float32, device=device)
y_train = torch.tensor(label_series.to_numpy(), dtype=torch.long, device=device)

for shape in (x_train.shape, y_train.shape):
    print(shape)

torch.Size([966, 150])
torch.Size([966])


# Train Model

In [23]:
# Softmax (multinomial logistic) regression trained with full-batch SGD + momentum.
Epochs = 10000
lr = 1e-3
nb_class = 7
nb_data = len(y_train)  # number of training samples
nb_feature = x_train.shape[1]  # input dimension, derived instead of hard-coding 150

# Create the parameters on the same device as the training tensors, using the
# notebook's `device` variable rather than a hard-coded 'cuda'.
W = torch.zeros((nb_feature, nb_class), requires_grad=True, device=device)
b = torch.zeros((nb_class,), requires_grad=True, device=device)  # one bias per class

optimizer = optim.SGD((W, b), lr=lr, momentum=0.9)  # SGD with momentum 0.9

for epoch in range(1, Epochs + 1):
    # F.cross_entropy applies log-softmax internally, so it takes raw logits.
    cost = F.cross_entropy(x_train.matmul(W) + b, y_train)

    optimizer.zero_grad()
    cost.backward()
    optimizer.step()

    if epoch == 1 or epoch % 1000 == 0:
        # .item() pulls the scalar loss to the host as a Python float for printing.
        print('Epoch: {:4d}/{} Cost: {:.6f}'.format(epoch, Epochs, cost.item()))

Epoch:    1/10000 Cost: 1.945921
Epoch: 1000/10000 Cost: 0.431529
Epoch: 2000/10000 Cost: 0.278477
Epoch: 3000/10000 Cost: 0.215607
Epoch: 4000/10000 Cost: 0.179142
Epoch: 5000/10000 Cost: 0.154578
Epoch: 6000/10000 Cost: 0.136596
Epoch: 7000/10000 Cost: 0.122725
Epoch: 8000/10000 Cost: 0.111631
Epoch: 9000/10000 Cost: 0.102519
Epoch: 10000/10000 Cost: 0.094880


# Test

In [25]:
# Load the test set: the same 150 feature columns as training, with no label
# column; the first CSV column again becomes the row index.
x_test = pd.read_csv('2020.AI.facePCA-test.csv', index_col=0)
x_test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149
0,-1.375219,-1.845628,-0.925274,-0.158622,-0.121293,0.636925,1.737201,-0.191841,-0.257884,0.527665,1.163346,-0.639700,0.544135,-0.601390,-0.368683,-0.190895,1.221793,2.728683,0.476333,-0.051149,-0.949190,-0.208254,-1.031771,0.308392,1.190954,1.163332,0.559764,-0.468263,-1.841640,0.266038,-0.697456,1.124188,-0.755710,0.707862,0.096643,-0.355635,-0.093965,1.587871,0.222673,-0.481568,...,-0.770736,0.161240,-0.846696,-0.283681,-0.941492,-0.885786,-0.245364,-0.881790,-0.573770,1.145700,0.922978,-0.913089,-1.027131,-0.567909,1.344121,-0.771880,0.890277,0.307017,-0.854513,1.304728,-0.735286,-1.614669,1.585679,2.331472,-0.684419,-0.317178,0.269693,1.084765,-0.461798,-0.161737,0.344800,-1.267994,1.672579,0.147600,-0.856595,1.012568,-0.805418,0.856378,-0.194109,-1.129755
1,-0.818803,1.519286,-0.682806,1.076556,0.181734,-0.810429,0.400994,0.498940,-0.726013,0.260842,-0.573281,-0.530764,0.664593,0.472009,0.259885,0.244592,1.076785,1.422447,-1.712762,-1.030667,-0.973780,-0.446320,1.865781,-0.121377,0.445549,-0.104164,-0.348535,-0.799038,-0.042183,0.360867,-0.704393,-1.278867,-0.105113,-0.320782,0.054084,0.255151,0.338615,0.610291,-0.936842,-0.025065,...,-1.883699,-0.447087,-1.080929,-1.422980,-0.452983,0.796951,-0.350637,1.251745,-0.828453,0.062714,0.519164,0.042157,0.910245,-0.523560,0.655003,0.136690,-0.268390,0.980128,0.274535,-1.130295,2.181644,-0.718002,-0.519800,-0.106526,-0.944108,-0.378691,0.005520,-0.063296,-1.080145,0.097826,1.437538,-0.135571,0.900427,-0.089965,-0.686085,0.536571,2.142441,0.005431,0.762059,-0.096975
2,-0.869844,-0.292945,-1.227929,-0.297976,-0.664880,-0.939109,0.013387,-0.727302,1.356753,-1.144051,-1.031264,0.453598,-0.827849,0.397944,0.437201,0.286109,-0.149297,0.164007,-0.324092,0.305691,-0.325101,1.259338,-1.044397,-0.518521,-1.055638,0.562607,2.704049,-0.084681,-0.847506,-0.363796,-0.017880,-0.173644,-0.279630,0.522414,0.793634,0.821183,-0.365311,0.619285,-0.292250,-0.005773,...,1.031070,-0.293566,0.229420,-0.586737,-0.133228,1.407845,-0.067706,-0.744287,-1.280458,-0.309188,-0.009044,-0.717974,-0.006219,0.391133,0.983139,0.185951,-0.022059,0.320318,-0.306566,0.694997,-1.030008,-0.940314,-1.507771,0.956281,1.834105,0.161191,0.287243,0.401221,0.197643,0.035524,-0.951785,1.915643,-1.640670,-1.206601,1.957250,-1.210786,0.292856,1.624620,-0.087121,1.696381
3,1.363942,-0.307408,0.980014,2.453974,0.218842,1.421033,-0.989311,-0.071045,-0.392554,-0.765995,1.577204,-0.066864,-0.182721,-0.465569,-1.550565,-0.120630,0.627956,0.280492,0.194193,-0.816669,0.349668,-1.763881,-0.980113,1.049254,-0.610803,1.278509,0.587294,-0.381649,-0.636837,0.522191,-0.641655,0.682473,-0.756016,0.746089,-0.107582,-1.179574,0.353339,0.030155,-0.675114,-1.476664,...,-0.513370,0.543065,-0.045438,-0.729456,0.357642,0.304425,0.661445,1.054563,0.830678,0.212049,0.184344,0.658019,-0.427653,0.723826,0.743508,-0.206367,0.415300,0.623760,-0.563690,-0.051167,-0.360780,-0.223417,0.830155,1.612943,-1.332355,0.165729,1.366950,0.627041,-0.102356,1.293533,0.002914,0.629358,-0.525619,1.658871,-0.711543,0.415380,0.754039,0.117611,-0.354967,-0.280746
4,-0.719613,-0.309311,-0.500165,0.100551,0.069662,1.058323,-0.884641,-0.331464,-1.116771,-1.037474,0.562824,0.811595,0.395730,-0.203637,-0.428481,0.627088,0.035763,-0.215167,0.724704,0.831763,-0.275152,0.083035,0.090208,0.108270,0.375636,0.175571,-0.310733,0.776454,-1.224595,0.193453,-0.618811,-0.267694,-1.037601,0.729905,0.327604,-1.449758,0.395276,-0.369709,0.659573,-1.209710,...,1.089238,0.943632,-0.590246,-0.153363,-0.663677,0.764101,-1.097823,-0.159946,0.319659,-1.526523,0.875131,0.774221,2.778785,-0.717434,-0.252040,0.411115,0.800209,-0.168334,1.001572,-0.514143,-0.541976,-0.266132,-0.669099,-0.697428,-0.316703,0.825868,-0.959087,0.676842,0.795356,0.410600,0.002025,-0.182412,0.087406,0.204192,-0.686670,1.210678,-1.746604,-0.316461,0.035818,-1.074916
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
317,0.803509,0.684933,-0.683767,-0.108604,-0.155155,-1.013400,1.894437,-0.983844,0.067618,0.175858,-0.018636,-1.484696,0.002606,-1.458824,-0.406325,-1.032493,0.774735,0.494453,1.332354,-1.897409,-0.793456,0.561186,-0.093244,-0.127169,-1.413250,-0.454358,-0.121798,-1.582177,-0.095422,-0.195259,0.118195,-2.591806,-0.739144,0.123858,-0.655811,-0.407243,0.422840,-0.224825,-0.723814,-1.299342,...,2.006737,0.684697,1.490928,-0.134988,-1.010003,-0.380909,-1.570750,0.809742,-0.545256,0.495985,-0.356967,0.335222,0.592199,0.347303,1.025782,-0.775210,1.253623,1.325518,-1.073384,-0.676411,-0.196549,0.038047,-1.241676,-0.894697,1.992848,-0.710599,-1.530570,0.027434,-0.953765,-0.021747,0.811345,-0.510938,-1.107407,0.581975,0.808569,1.069255,-1.152038,-1.984130,0.500089,0.910989
318,0.725379,-2.876888,0.809781,1.454774,0.224179,0.831514,0.156112,-0.805710,0.270198,0.490332,0.774472,-0.334351,-0.284235,0.136382,-0.175447,0.156609,1.260233,-0.234391,-0.372286,1.426326,-1.033333,-0.228425,-2.123115,1.813649,-0.544391,0.996927,-1.186220,0.703614,1.109814,0.884458,-0.937957,0.996176,0.589331,0.453303,-1.120939,0.417860,-0.877057,0.815737,-1.675642,-0.607418,...,0.219238,0.011264,1.237881,-0.814693,0.920353,0.683985,-0.320586,-0.304908,0.968788,0.006397,0.346135,1.098846,-0.248289,0.246305,0.174223,0.212435,1.454289,-0.483447,1.257194,-0.399061,0.055828,-0.071062,-0.448349,1.616799,-2.690924,0.383301,-0.744056,0.341023,-0.157247,-1.391594,-1.014993,0.338920,-0.022514,0.026659,0.548030,-0.310831,0.207058,0.440281,0.511577,0.628985
319,0.153744,-0.717782,0.838810,0.509070,-0.574993,-1.663966,-1.339896,1.098158,0.817291,0.850823,-2.732243,-0.068814,-0.118254,-0.430479,-0.167424,-0.460925,-0.323528,0.241845,0.307938,-0.830338,-0.514417,0.244866,0.685637,1.007379,1.068363,0.433442,0.498509,0.373885,-0.754033,-0.201994,-1.280866,-0.215995,1.421019,-0.373411,-0.974327,0.521906,0.243701,-0.566726,-0.338031,1.330461,...,1.421779,-0.918042,1.426771,-0.337534,0.493950,-0.441272,0.073341,0.929493,0.685699,-1.071399,0.538115,-1.836118,0.753503,0.185836,0.171234,0.443478,0.202768,0.458904,-0.526348,-0.753682,-2.474964,1.619320,-0.631002,-0.407847,0.978093,-0.185852,-0.441703,1.298674,-0.077299,1.006483,-0.766445,-0.469357,0.497239,-0.492126,-0.602005,-1.534660,0.545223,-1.987615,-0.211417,0.904811
320,0.057009,0.483827,-0.153277,2.636805,-1.084646,-0.791640,-0.271040,1.022719,0.418752,-1.707415,-1.636300,2.031785,0.581914,-1.596544,0.017216,-0.605585,0.522210,0.326285,1.232036,-0.093390,1.648812,-0.061472,0.650171,-1.194274,0.389265,-1.597134,-0.183011,0.551465,0.469612,0.829700,-0.947815,0.340434,-1.353212,1.159065,1.748087,-0.289025,0.405403,-0.794390,-2.061640,-0.327438,...,-0.850960,-0.703638,0.070207,1.443368,-0.661625,-0.844443,-0.598602,-0.601372,-0.498278,-2.365959,-1.105482,2.740462,0.047426,-0.221524,-0.791837,1.313484,1.022207,0.472421,-0.900083,-1.009877,0.873272,-0.351226,-0.535431,-1.692321,-0.497642,-0.551513,-0.085840,-1.727797,0.582926,-0.354644,-0.506249,-0.480262,-0.454741,1.076510,0.638840,-0.602263,-0.969037,1.219474,-0.063797,-0.820716


In [0]:
# Convert the test features to a float32 tensor and move it to the same device
# as the model parameters, as was done for x_train; otherwise the matmul with
# W/b fails with a device mismatch when training ran on the GPU.
x_test = np.array(x_test)
x_test = torch.FloatTensor(x_test).to(device)

In [38]:
# Predict a class index for every test sample. Inference only, so gradient
# tracking is disabled. x_test is aligned to W's device so the matmul works
# whether or not it was moved earlier.
with torch.no_grad():
    H = F.softmax(x_test.to(W.device).matmul(W) + b, dim=1)  # per-class probabilities
    # CPU copy so the submission cell can call .numpy() on it.
    predict = torch.argmax(H, dim=1).cpu()
predict

tensor([3, 3, 6, 3, 3, 3, 4, 1, 3, 3, 3, 3, 3, 6, 3, 3, 3, 1, 3, 4, 1, 3, 3, 3,
        0, 1, 0, 3, 3, 3, 2, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 1, 1, 4, 3,
        2, 3, 3, 3, 3, 3, 6, 2, 1, 3, 5, 3, 1, 1, 0, 4, 2, 5, 6, 4, 1, 3, 4, 6,
        3, 3, 3, 2, 1, 6, 4, 4, 4, 0, 4, 3, 3, 3, 5, 3, 3, 2, 3, 6, 3, 1, 1, 6,
        1, 1, 6, 6, 3, 1, 3, 1, 3, 1, 3, 3, 3, 3, 4, 1, 3, 3, 3, 1, 3, 4, 1, 3,
        1, 3, 3, 0, 3, 4, 4, 3, 1, 3, 6, 6, 6, 3, 2, 4, 3, 3, 1, 6, 2, 2, 5, 1,
        3, 6, 1, 3, 6, 1, 1, 1, 1, 3, 3, 3, 6, 1, 1, 1, 6, 5, 5, 1, 3, 1, 5, 1,
        2, 3, 3, 1, 6, 1, 5, 1, 3, 2, 2, 1, 3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 2,
        3, 3, 6, 3, 3, 6, 3, 6, 3, 2, 1, 2, 3, 1, 6, 2, 0, 2, 3, 4, 3, 3, 3, 3,
        3, 2, 3, 1, 2, 3, 1, 1, 6, 3, 3, 3, 1, 3, 3, 3, 1, 0, 3, 1, 6, 3, 4, 3,
        3, 4, 2, 4, 3, 0, 3, 3, 3, 4, 4, 3, 2, 4, 3, 4, 2, 1, 6, 3, 2, 3, 2, 1,
        3, 6, 1, 1, 3, 6, 1, 1, 3, 3, 4, 3, 3, 3, 3, 3, 1, 0, 3, 3, 1, 0, 3, 3,
        3, 4, 4, 3, 5, 1, 2, 1, 4, 5, 3,

In [0]:
# Build the Kaggle submission: one row per test sample with its predicted class.
# detach() drops any autograd history; cpu() is required before numpy() when
# `predict` lives on the GPU.
Category = predict.detach().cpu().numpy()
ids = np.arange(len(x_test))  # row ids 0..N-1; renamed from `id`, which shadowed the builtin

submit = pd.DataFrame({'id': ids, 'Category': Category})
submit.to_csv('submit.csv', index=False, header=True)

In [39]:
! kaggle competitions submit -c 2020-ai-exam-facepca-revisit -f submit.csv -m 'First Try'

100% 1.79k/1.79k [00:04<00:00, 380B/s]
Traceback (most recent call last):
  File "/usr/local/bin/kaggle", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python2.7/dist-packages/kaggle/cli.py", line 64, in main
    print(out, end='')
UnicodeEncodeError: 'latin-1' codec can't encode characters in position 34-37: ordinal not in range(256)
