Permalink
Browse files

Built a new wrapper around reflection calls to unify the process of m…

…apping types to tables. Tests added.

git-svn-id: http://sqlite-net.googlecode.com/svn/trunk@25 fb26c452-ae10-11de-9734-99f05f10ca21
  • Loading branch information...
1 parent ee23af2 commit 6417d460f96d4b96e25eb963feec17f61105cbc9 fak@praeclarum.org committed Jan 23, 2010
Showing with 279 additions and 44 deletions.
  1. +1 −1 Makefile
  2. +1 −0 SQLite.Tests/SQLite.Tests.csproj
  3. +20 −0 SQLite.sln
  4. +147 −43 src/SQLite.cs
  5. +67 −0 tests/InsertTest.cs
  6. +43 −0 tests/SQLite.Tests.csproj
View
@@ -36,4 +36,4 @@ dist:
rm -Rf $(DIST)/examples/StocksTouch/*.userprefs
rm -Rf $(DIST)/.DS_Store
zip -9 -r $(DIST).zip $(DIST)
-
+ rm -Rf $(DIST)
@@ -0,0 +1 @@
+<?xml version="1.0" encoding="utf-8"?><Project DefaultTargets="Build" ToolsVersion="3.5" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> <PropertyGroup> <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> <ProductVersion>9.0.21022</ProductVersion> <SchemaVersion>2.0</SchemaVersion> <ProjectGuid>{6947A8F1-99BE-4DD1-AD4D-D89425CE67A2}</ProjectGuid> <OutputType>Library</OutputType> <RootNamespace>SQLite.Tests</RootNamespace> <AssemblyName>SQLite.Tests</AssemblyName> <TargetFrameworkVersion>v3.5</TargetFrameworkVersion> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> <DebugSymbols>true</DebugSymbols> <DebugType>full</DebugType> <Optimize>false</Optimize> <OutputPath>bin\Debug</OutputPath> <DefineConstants>DEBUG</DefineConstants> <ErrorReport>prompt</ErrorReport> <WarningLevel>4</WarningLevel> <ConsolePause>false</ConsolePause> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> <DebugType>none</DebugType> <Optimize>false</Optimize> <OutputPath>bin\Release</OutputPath> <ErrorReport>prompt</ErrorReport> <WarningLevel>4</WarningLevel> <ConsolePause>false</ConsolePause> </PropertyGroup> <ItemGroup> <Reference Include="System" /> </ItemGroup> <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" /></Project>
View
@@ -0,0 +1,20 @@
+
+Microsoft Visual Studio Solution File, Format Version 10.00
+# Visual Studio 2008
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SQLite.Tests", "tests\SQLite.Tests.csproj", "{6947A8F1-99BE-4DD1-AD4D-D89425CE67A2}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {6947A8F1-99BE-4DD1-AD4D-D89425CE67A2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {6947A8F1-99BE-4DD1-AD4D-D89425CE67A2}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {6947A8F1-99BE-4DD1-AD4D-D89425CE67A2}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {6947A8F1-99BE-4DD1-AD4D-D89425CE67A2}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(MonoDevelopProperties) = preSolution
+ StartupItem = tests\SQLite.Tests.csproj
+ EndGlobalSection
+EndGlobal
View
@@ -45,16 +45,17 @@ public class SQLiteConnection : IDisposable
{
private IntPtr _db;
private bool _open;
+ Dictionary<Guid, TypeTableMapping> _mappings = null;
- public string Database { get; set; }
+ public string DatabasePath { get; set; }
public bool Trace { get; set; }
- public SQLiteConnection (string database)
+ public SQLiteConnection (string databasePath)
{
- Database = database;
- var r = SQLite3.Open (Database, out _db);
+ DatabasePath = databasePath;
+ var r = SQLite3.Open (DatabasePath, out _db);
if (r != SQLite3.Result.OK) {
- throw SQLiteException.New (r, "Could not open database file: " + Database);
+ throw SQLiteException.New (r, "Could not open database file: " + DatabasePath);
}
_open = true;
}
@@ -93,14 +94,40 @@ public int CreateTable<T> ()
return count;
}
+
+ public TypeTableMapping GetMapping(Type type) {
+ if (_mappings == null) {
+ _mappings = new Dictionary<Guid, TypeTableMapping>();
+ }
+ TypeTableMapping map;
+ if (!_mappings.TryGetValue(type.GUID, out map)) {
+ map = new TypeTableMapping(type);
+ _mappings[type.GUID] = map;
+ }
+ return map;
+ }
+
+ System.Diagnostics.Stopwatch _sw;
+ long _elapsedMilliseconds = 0;
public int Execute (string query, params object[] ps)
{
var cmd = CreateCommand (query, ps);
if (Trace) {
Console.WriteLine ("Executing: " + cmd);
+ if (_sw == null) {
+ _sw = new System.Diagnostics.Stopwatch();
+ }
+ _sw.Reset();
+ _sw.Start();
}
- return cmd.ExecuteNonQuery ();
+ int r = cmd.ExecuteNonQuery ();
+ if (Trace) {
+ _sw.Stop();
+ _elapsedMilliseconds += _sw.ElapsedMilliseconds;
+ Console.WriteLine ("Finished in {0} ms ({1:0.0} s total)", _sw.ElapsedMilliseconds, _elapsedMilliseconds/1000.0);
+ }
+ return r;
}
public IEnumerable<T> Query<T> (string query, params object[] ps) where T : new()
@@ -109,29 +136,38 @@ public IEnumerable<T> Query<T> (string query, params object[] ps) where T : new(
return cmd.ExecuteQuery<T> ();
}
- public int InsertAll<T> (IEnumerable<T> rows)
+ /// <summary>
+ /// Inserts all specified objects.
+ /// </summary>
+ /// <param name="objects">
+ /// A <see cref="IEnumerable<T>"/> of objects to insert.
+ /// </param>
+ /// <returns>
+ /// A <see cref="System.Int32"/> of the number of rows added to the table.
+ /// </returns>
+ public int InsertAll<T> (IEnumerable<T> objects)
{
var c = 0;
- foreach (var r in rows) {
- c += Insert (r);
+ var map = GetMapping(typeof(T));
+ foreach (var r in objects) {
+ c += Insert (r, map);
}
return c;
}
public int Insert<T> (T obj)
{
- var type = obj.GetType ();
- var cols = Orm.GetColumns (type).Where(c => !Orm.IsAutoInc(c));
- var q = string.Format ("insert into '{0}'({1}) values ({2})", type.Name, string.Join (",", (from c in cols
- select "'" + c.Name + "'").ToArray ()), string.Join (",", (from c in cols
- select "?").ToArray ()));
- var vals = from c in cols
- select c.GetValue (obj, null);
+ return Insert (obj, GetMapping(obj.GetType()));
+ }
+ public int Insert<T> (T obj, TypeTableMapping map)
+ {
+ var vals = from c in map.InsertColumns
+ select c.GetValue (obj);
- var count = Execute (q, vals.ToArray ());
+ var count = Execute (map.InsertSql, vals.ToArray ());
var id = SQLite3.LastInsertRowid(_db);
- Orm.SetAutoIncPK(obj, id);
+ map.SetAutoIncPK(obj, id);
return count;
}
@@ -148,37 +184,34 @@ public void Delete<T>(T obj)
pk.Name);
Execute(q, pk.GetValue(obj, null));
}
-
+
public T Get<T> (object pk) where T : new()
{
- string query = string.Format ("select * from '{0}' where '{1}'=?", typeof(T).Name, Orm.GetPK (typeof(T)).Name);
+ var map = GetMapping(typeof(T));
+ string query = string.Format ("select * from '{0}' where '{1}' = ?", map.TableName, map.PK.Name);
return Query<T> (query, pk).First ();
}
public int Update (object obj)
{
- if (obj == null)
- return 0;
- return Update (obj.GetType ().Name, obj);
- }
+ if (obj == null) { return 0; }
+
+ var map = GetMapping(obj.GetType ());
+ var props = map.Columns;
+ var pk = map.PK;
- public int Update (string name, object obj)
- {
- var type = obj.GetType ();
- var props = Orm.GetColumns (type);
- var pk = Orm.GetPK (type);
if (pk == null) {
- throw new NotSupportedException ("Cannot update " + name + ": it has no PK");
+ throw new NotSupportedException ("Cannot update " + map.TableName + ": it has no PK");
}
var cols = from p in props
where p != pk
select p;
var vals = from c in cols
- select c.GetValue (obj, null);
+ select c.GetValue (obj);
var ps = new List<object> (vals);
- ps.Add(pk.GetValue(obj, null));
+ ps.Add(pk.GetValue(obj));
var q = string.Format("update '{0}' set {1} where {2} = ? ",
- type.Name,
+ map.TableName,
string.Join(",", (from c in cols select "'" + c.Name + "' = ? ").ToArray()),
pk.Name);
return Execute (q, ps.ToArray ());
@@ -211,6 +244,83 @@ public MaxLengthAttribute (int length)
Value = length;
}
}
+
+ public class TypeTableMapping {
+ public Type MappedType { get; private set; }
+ public Column[] Columns { get; private set; }
+
+ public TypeTableMapping(Type type) {
+ MappedType = type;
+ TableName = MappedType.Name;
+ var props = MappedType.GetProperties();
+ Columns = new Column[props.Length];
+ for (int i = 0; i < props.Length; i++) {
+ Columns[i] = new PropColumn(props[i]);
+ }
+ foreach (var c in Columns) {
+ if (c.IsAutoInc && c.IsPK) {
+ _autoPk = c;
+ }
+ if (c.IsPK) {
+ PK = c;
+ }
+ }
+ }
+ public string TableName { get; private set; }
+ public Column PK { get; private set; }
+ Column _autoPk = null;
+ public void SetAutoIncPK(object obj, long id) {
+ if (_autoPk != null) {
+ _autoPk.SetValue(obj, Convert.ChangeType(id, _autoPk.ColumnType));
+ }
+ }
+ string _insertSql = null;
+ Column[] _insertColumns = null;
+ public Column[] InsertColumns {
+ get {
+ if (_insertColumns == null) {
+ _insertColumns = Columns.Where(c => !c.IsAutoInc).ToArray();
+ }
+ return _insertColumns;
+ }
+ }
+ public string InsertSql {
+ get {
+ if (_insertSql == null) {
+ var cols = InsertColumns;
+ _insertSql = string.Format ("insert into '{0}'({1}) values ({2})", TableName, string.Join (",", (from c in cols
+ select "'" + c.Name + "'").ToArray ()), string.Join (",", (from c in cols
+ select "?").ToArray ()));
+ }
+ return _insertSql;
+ }
+ }
+ public abstract class Column {
+ public string Name { get; protected set; }
+ public Type ColumnType { get; protected set; }
+ public bool IsAutoInc { get; protected set; }
+ public bool IsPK { get; protected set; }
+ public abstract void SetValue(object obj, object val);
+ public abstract object GetValue(object obj);
+ }
+ public class PropColumn : Column {
+ PropertyInfo _prop;
+ public PropColumn(PropertyInfo prop) {
+ _prop = prop;
+ Name = prop.Name;
+ ColumnType = prop.PropertyType;
+ IsAutoInc = Orm.IsAutoInc(prop);
+ IsPK = Orm.IsPK(prop);
+ }
+ public override void SetValue(object obj, object val) {
+ _prop.SetValue(obj, val, null);
+ }
+ public override object GetValue(object obj) {
+ return _prop.GetValue(obj, null);
+ }
+ }
+ }
+
public static class Orm
{
@@ -252,19 +362,19 @@ public static string SqlType (PropertyInfo p)
}
}
- public static bool IsPK (PropertyInfo p)
+ public static bool IsPK (MemberInfo p)
{
var attrs = p.GetCustomAttributes (typeof(PrimaryKeyAttribute), true);
return attrs.Length > 0;
}
- public static bool IsAutoInc (PropertyInfo p)
+ public static bool IsAutoInc (MemberInfo p)
{
var attrs = p.GetCustomAttributes (typeof(AutoIncrementAttribute), true);
return attrs.Length > 0;
}
- public static bool IsIndexed (PropertyInfo p)
+ public static bool IsIndexed (MemberInfo p)
{
var attrs = p.GetCustomAttributes (typeof(IndexedAttribute), true);
return attrs.Length > 0;
@@ -285,7 +395,7 @@ public static int MaxStringLength (PropertyInfo p)
}
}
- public static System.Reflection.PropertyInfo GetPK (Type t)
+ public static System.Reflection.PropertyInfo GetPKo (Type t)
{
var props = GetColumns (t);
foreach (var p in props) {
@@ -303,12 +413,6 @@ where p.CanWrite
select p;
}
- public static void SetAutoIncPK(object obj, long id) {
- var pk = GetPK(obj.GetType());
- if (pk != null && IsAutoInc(pk)) {
- pk.SetValue(obj, Convert.ChangeType(id, pk.PropertyType), null);
- }
- }
}
public class SQLiteCommand
View
@@ -0,0 +1,67 @@
+
+using System;
+using System.IO;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using SQLite;
+
+using NUnit.Framework;
+
+namespace SQLite.Tests
+{
+ [TestFixture]
+ public class InsertTest
+ {
+ public class TestObj
+ {
+ [AutoIncrement, PrimaryKey]
+ public int Id { get; set; }
+ public String Text { get; set; }
+
+ public override string ToString ()
+ {
+ return string.Format("[TestObj: Id={0}, Text={1}]", Id, Text);
+ }
+
+ }
+ public class TestDb : SQLiteConnection
+ {
+ public TestDb(String path)
+ : base(path)
+ {
+ CreateTable<TestObj>();
+ }
+
+ }
+
+ [Test]
+ public void InsertALot()
+ {
+ int n = 500;
+ var q = from i in Enumerable.Range(1, n)
+ select new TestObj() {
+ Text = "I am"
+ };
+ var objs = q.ToArray();
+ var db = new TestDb(Path.GetTempFileName());
+ db.Trace = true;
+
+ var numIn = db.InsertAll(objs);
+
+ Assert.AreEqual(numIn, n, "Num inserted must = num objects");
+
+ var inObjs = db.CreateCommand("select * from TestObj").ExecuteQuery<TestObj>().ToArray();
+
+ for (var i = 0; i < inObjs.Length; i++) {
+ Assert.AreEqual(i+1, objs[i].Id);
+ Assert.AreEqual(i+1, inObjs[i].Id);
+ Assert.AreEqual("I am", inObjs[i].Text);
+ }
+
+ var numCount = db.CreateCommand("select count(*) from TestObj").ExecuteScalar<int>();
+
+ Assert.AreEqual(numCount, n, "Num counted must = num objects");
+ }
+ }
+}
Oops, something went wrong.

0 comments on commit 6417d46

Please sign in to comment.