In [15]:
import csv
import time

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel

In [16]:
# Load the item table; train() below reads its 'id' and 'description' columns.
description_csv = '../DataIO/data/description.csv'
df = pd.read_csv(description_csv)

In [18]:
#this block trains the model

def train(df, top_n=98):
    """Build a content-based similarity lookup from item descriptions.

    Fits a TF-IDF model (word 1- to 3-grams, English stop words removed) on
    ``df['description']`` and scores every pair of rows with ``linear_kernel``
    (the dot product of their TF-IDF vectors).

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide an ``id`` column and a ``description`` column.
    top_n : int, optional
        Maximum number of similar items kept per item. The default of 98 is
        the count the previous hard-coded ``[:-100:-1]`` slice produced once
        the item itself had been removed.

    Returns
    -------
    dict
        ``{item_id: [(score, other_item_id), ...]}``; each list is ordered by
        descending score and never contains the item itself.

    Raises
    ------
    ValueError
        If ``top_n`` is negative.
    """
    if top_n < 0:
        raise ValueError("top_n must be non-negative")

    start_time = time.time()
    # min_df=1 keeps every term, exactly like the former min_df=0, but
    # scikit-learn releases that validate parameters reject an integer 0.
    tf = TfidfVectorizer(analyzer='word', ngram_range=(1, 3), min_df=1, stop_words='english')
    tfidf_matrix = tf.fit_transform(df['description'])
    cosine_similarities = linear_kernel(tfidf_matrix, tfidf_matrix)

    # Row k of the similarity matrix belongs to the k-th row of df, so ids are
    # looked up by position; index labels (which need not be 0..n-1, e.g.
    # after filtering) must not be used to address the matrix.
    item_ids = df['id'].to_numpy()
    results = {}
    for pos, item_id in enumerate(df['id'].tolist()):
        scores = cosine_similarities[pos]
        # Fetch one extra candidate so that dropping the item itself still
        # leaves up to top_n neighbours.
        ranked = scores.argsort()[::-1][:top_n + 1]
        # Exclude the item by position instead of assuming it sorts first:
        # an item with an identical description ties with it and may rank
        # ahead of it.
        # Each dictionary entry is like: [(1,2), (3,4)], with each tuple being (score, item_id)
        results[item_id] = [(scores[i], item_ids[i]) for i in ranked if i != pos][:top_n]
    print("Training completed. It takes %s seconds" % (time.time() - start_time))
    return results

# Build the similarity lookup for the loaded frame; train() prints the elapsed time.
results = train(df)

Training completed. It takes 4.348368406295776 seconds


In [19]:
# Display the whole mapping {item id: [(score, item id), ...]}.
# NOTE(review): this dumps every entry; consider showing a small slice instead.
results

{2175: [(0.02938901731817859, 1935),
  (0.024998440208386513, 1659),
  (0.02289063672937121, 1157),
  (0.019027029520670796, 1473),
  (0.01822233520855849, 63),
  (0.017611903870247393, 698),
  (0.016519255861525212, 504),
  (0.015708490478877533, 1462),
  (0.015497670497522306, 1646),
  (0.01446396906055845, 797),
  (0.013927082925762693, 314),
  (0.013475033485968015, 1181),
  (0.012870239932950084, 1826),
  (0.01286126970190605, 1488),
  (0.012689800538127108, 2190),
  (0.012229943617628133, 647),
  (0.012130680667092431, 825),
  (0.011537215905871683, 2032),
  (0.011316985455893453, 911),
  (0.010790333958171689, 681),
  (0.010790333958171689, 702),
  (0.01053775492425508, 339),
  (0.01040035411882649, 2156),
  (0.010313242460143122, 1509),
  (0.010091424181138498, 975),
  (0.01001183526886754, 323),
  (0.009735616322502149, 1399),
  (0.009573492765948727, 854),
  (0.009516485069285568, 2299),
  (0.009399670674831412, 1484),
  (0.00925130926958356, 2192),
  (0.008946792502157713, 8

In [21]:
# `results` is the dictionary returned by train():
#   key   -> product id
#   value -> list of (score, product id) pairs
# Save it to CSV, one row per product: [product id, str(list of pairs)].
# The csv module on Python 3 needs a text-mode file opened with newline='';
# the previous binary mode ('wb') makes writerow() raise TypeError.
with open('itemSimilarity.csv', 'w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    for key, value in results.items():
        writer.writerow([key, value])

In [22]:
# NOTE(review): df1 is not defined in any visible cell, so this fails on a
# fresh kernel. Presumably it was read back from a saved similarity CSV;
# confirm and add the cell that creates it.
df1

Unnamed: 0,2175,1972,1974,1971,2270,1979,1973,2271,2272,1976,...,1942,1943,1944,1953,1964,1954,1950,1951,1961,1958
0,"(0.02938901731817859, 1935)","(0.04511711705440344, 1042)","(0.9670357174448646, 2048)","(0.17348646102820342, 847)","(0.030571859899593143, 813)","(0.03863263990933129, 1683)","(0.1543490006848947, 1687)","(0.21248155016599696, 2203)","(0.05797978561757218, 703)","(0.020749316247498546, 1486)",...,"(0.16345719454687074, 175)","(0.21587495723120592, 274)","(0.28372657982393296, 103)","(0.09298021561392532, 1109)","(0.051742547588820006, 1956)","(0.9999999999999996, 2339)","(0.05090470112286692, 2332)","(0.12642223185473617, 2229)","(0.03085163757281484, 877)","(0.04921526250805688, 387)"
1,"(0.024998440208386513, 1659)","(0.0390757258341523, 1513)","(0.3855910628046723, 1975)","(0.13193059506681823, 1031)","(0.02335731790389501, 2137)","(0.029301282171621746, 1337)","(0.09142784565572362, 1949)","(0.11690127908958532, 54)","(0.0544404055215618, 711)","(0.020667436140848577, 462)",...,"(0.10578064534555438, 884)","(0.14511846154974314, 341)","(0.28073549078866394, 1256)","(0.09298021561392532, 693)","(0.048902329963978274, 1040)","(0.04636558327513157, 151)","(0.039713525022801024, 1749)","(0.05253966404381277, 86)","(0.030352100511519542, 876)","(0.04289122382184189, 133)"
2,"(0.02289063672937121, 1157)","(0.0365307755459007, 464)","(0.02968656392096251, 2114)","(0.09755368319870678, 793)","(0.01998712206929655, 836)","(0.02779535874614026, 2279)","(0.07457277559634051, 2121)","(0.11200158810913341, 1812)","(0.045574149397500464, 2033)","(0.01982000353348326, 2189)",...,"(0.1013766630002343, 816)","(0.13657091348761635, 232)","(0.2447855346844887, 222)","(0.05648286931584814, 2127)","(0.028950906959553078, 757)","(0.03457432810596724, 792)","(0.03915270068187415, 608)","(0.043764476242763466, 261)","(0.028343457426327285, 67)","(0.03790884945631879, 472)"
3,"(0.019027029520670796, 1473)","(0.03535424101721797, 952)","(0.026808347816820566, 1099)","(0.08926038340833466, 999)","(0.016264019679297158, 2295)","(0.02683964866865183, 2245)","(0.06441628558021398, 1386)","(0.10328532956443287, 1114)","(0.04296778954104898, 2271)","(0.018423057378335522, 669)",...,"(0.09650124120625553, 167)","(0.10697074774244876, 1582)","(0.24334797309178974, 104)","(0.0564258404105831, 1552)","(0.028939356086260474, 94)","(0.03386276971433205, 282)","(0.03492686111253136, 2064)","(0.04256087160270386, 259)","(0.02606999465213723, 2041)","(0.035159702465144584, 1443)"
4,"(0.01822233520855849, 63)","(0.03416617276001221, 2168)","(0.025805846322091534, 588)","(0.07506923123049454, 1931)","(0.015736570871587134, 572)","(0.02511947535211167, 162)","(0.046622283577418676, 2042)","(0.08189037202390217, 1660)","(0.04031863440404446, 527)","(0.01761014481346515, 208)",...,"(0.08989224336064552, 93)","(0.09528024948013528, 2358)","(0.2250521242320563, 2084)","(0.051122042369479294, 1208)","(0.027116419818510196, 346)","(0.03146968593520918, 1215)","(0.0347002518630949, 2244)","(0.036524288818385195, 859)","(0.021160678057894095, 1202)","(0.0304994104946854, 1537)"
5,"(0.017611903870247393, 698)","(0.03239289420381071, 587)","(0.024367228999377686, 2039)","(0.06543539439784471, 1040)","(0.012995586621348067, 1697)","(0.0248777598213671, 1217)","(0.040957399409120056, 1154)","(0.07404640091164398, 1729)","(0.036805952928875715, 1813)","(0.017305516654679665, 1866)",...,"(0.07872456412270365, 239)","(0.08296205929048829, 91)","(0.20945658026165678, 251)","(0.04598478161215506, 1502)","(0.026468041649109195, 2094)","(0.03146968593520918, 2224)","(0.030951897387871477, 998)","(0.03515382446775215, 727)","(0.020299334191926425, 5)","(0.02580279622373636, 2005)"
6,"(0.016519255861525212, 504)","(0.0321438041182633, 1688)","(0.023774561598530745, 1480)","(0.061807199328593104, 996)","(0.012995586621348067, 1758)","(0.023930159973509297, 2154)","(0.03775007460394709, 317)","(0.06702798146055758, 1527)","(0.03584808763067289, 639)","(0.01713583387421744, 2181)",...,"(0.06415028259027113, 2103)","(0.06675261572460578, 334)","(0.20746301315804855, 147)","(0.04124804985580676, 1621)","(0.02606481328276187, 1883)","(0.03146968593520918, 1158)","(0.029056529723041576, 1842)","(0.0315793754273709, 1760)","(0.019221482413866314, 1270)","(0.023998420412202525, 2298)"
7,"(0.015708490478877533, 1462)","(0.030042305380980946, 812)","(0.023529304516730663, 270)","(0.0593366694804334, 1010)","(0.012763979513732707, 1911)","(0.023432280662033532, 235)","(0.03688155994035296, 1485)","(0.06688036745723634, 577)","(0.035727217068291725, 330)","(0.017063212515787827, 299)",...,"(0.05960958551948801, 574)","(0.06509343420743331, 1840)","(0.1987049738885214, 1899)","(0.04064489535726815, 1884)","(0.02509167005186339, 235)","(0.028686127127348006, 1559)","(0.027909161712528656, 150)","(0.029091340279254634, 183)","(0.018264607286239968, 695)","(0.02133960450112304, 1129)"
8,"(0.015497670497522306, 1646)","(0.029503792638115874, 446)","(0.02330443309789009, 867)","(0.05078403281587786, 741)","(0.011636117471639068, 1884)","(0.02305211911938919, 314)","(0.03686443692874891, 1737)","(0.06537958960110755, 639)","(0.03545381969561998, 699)","(0.016385651392887914, 989)",...,"(0.05420322372606797, 847)","(0.06295476979883338, 335)","(0.18099710287115506, 1045)","(0.03995469955054477, 1130)","(0.024524312421092872, 2082)","(0.026685126662451715, 1808)","(0.02678272958728631, 2237)","(0.028428876401448606, 1332)","(0.01800145095957588, 719)","(0.02133960450112304, 1521)"
9,"(0.01446396906055845, 797)","(0.028978222154035026, 2130)","(0.021889878318633944, 536)","(0.04249279359897137, 740)","(0.011583602344549401, 1755)","(0.02178727159106417, 2123)","(0.03523444789828577, 1538)","(0.06510997655771154, 1617)","(0.031316750178647906, 1847)","(0.01621553287884531, 717)",...,"(0.05113812389435384, 1037)","(0.06191486611264406, 1885)","(0.17704014945775137, 256)","(0.03730472113213243, 1111)","(0.02214769238420224, 1479)","(0.02658054381457483, 851)","(0.026485530968154293, 1422)","(0.02783734701003402, 125)","(0.017638795026916978, 1629)","(0.018891501739819454, 1285)"
